1use std::{
2 env,
3 fs::read_dir,
4 path::{Path, PathBuf},
5};
6
7use crate::{
8 Os,
9 dnnl::build_dnnl,
10 download::download_helper,
11 export,
12 file_changes::watch_dir_recursively,
13 link_dynamic_libraries, link_libraries,
14 native::{build_native, cuda_root},
15 submodules,
16 windows_crt_patch::patch_cmake_runtime_flags,
17};
18
19pub fn link(
20 os: Os,
21 cuda: bool,
22 cudnn: bool,
23 cuda_dynamic_loading: bool,
24 openblas: bool,
25 dnnl: bool,
26 accelarate: bool,
27 openmp_comp: bool,
28 openmp_intel: bool,
29 cuda_root: Option<PathBuf>,
30 shared: bool,
31) {
32 if cuda && !shared {
33 if let Some(cuda) = cuda_root {
34 println!("cargo:rustc-link-search={}", cuda.join("lib").display());
35 println!("cargo:rustc-link-search={}", cuda.join("lib64").display());
36 println!("cargo:rustc-link-search={}", cuda.join("lib/x64").display());
37 }
38
39 println!("cargo:rustc-link-lib=static=cudart_static");
40 if cudnn {
41 println!("cargo:rustc-link-lib=cudnn");
42 }
43 if !cuda_dynamic_loading {
44 if os == Os::Win {
45 println!("cargo:rustc-link-lib=static=cublas");
46 println!("cargo:rustc-link-lib=static=cublasLt");
47 } else {
48 println!("cargo:rustc-link-lib=static=cublas_static");
49 println!("cargo:rustc-link-lib=static=cublasLt_static");
50 println!("cargo:rustc-link-lib=static=culibos");
51 }
52 }
53 }
54
55 if openblas && !shared {
56 println!("cargo:rustc-link-lib=static=openblas");
57 }
58 if accelarate {
59 println!("cargo:rustc-link-lib=framework=Accelerate");
60 }
61 if dnnl {
62 }
64 if openmp_comp && !shared {
65 println!("cargo:rustc-link-lib=gomp");
66 } else if openmp_intel && !shared {
67 if os == Os::Win {
68 println!("cargo:rustc-link-lib=dylib=libiomp5md");
69 } else {
70 println!("cargo:rustc-link-lib=iomp5");
71 }
72 }
73}
74
/// Separator between entries of `PATH`-like environment variables
/// (`LIBRARY_PATH`, `CMAKE_LIBRARY_PATH`, ...) on the platform the build
/// script runs on: `;` on Windows, `:` everywhere else.
const PATH_SEPARATOR: char = if cfg!(target_os = "windows") { ';' } else { ':' };
80
81fn add_search_paths(key: &str) {
82 println!("cargo:rerun-if-env-changed={}", key);
83 if let Ok(library_path) = env::var(key) {
84 library_path
85 .split(PATH_SEPARATOR)
86 .filter(|v| !v.is_empty())
87 .for_each(|v| {
88 println!("cargo:rustc-link-search={}", v);
89 });
90 }
91}
92
93fn get_download_link(
94 os: Os,
95 version: &str,
96 aarch64: bool,
97 shared: bool,
98 crt_dyn: bool,
99) -> Option<String> {
100 Some(format!(
101 "https://github.com/frederik-uni/ctranslate2-src/releases/download/v{version}/ctranslate2-{}{}-{}-{}.tar.gz",
102 if shared { "shared" } else { "static" },
103 if crt_dyn && os == Os::Win { "-crt" } else { "" },
104 match os {
105 Os::Win => "windows",
106 Os::Mac => "macos",
107 Os::Linux => "linux",
108 Os::Unknown => return None,
109 },
110 match aarch64 {
111 true => "arm64",
112 false => "x86_64",
113 }
114 ))
115}
116
/// Returns the directory three levels above `OUT_DIR`.
///
/// With Cargo's usual build layout (`<target>/<profile>/build/<pkg>-<hash>/out`)
/// this is typically the `<target>/<profile>` directory. That is an assumption
/// about Cargo's layout, not something checked here.
///
/// # Panics
/// Panics if `OUT_DIR` is unset or not valid Unicode, or if it has fewer than
/// three ancestors.
fn get_dir() -> PathBuf {
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    // `ancestors()` yields the path itself first, so index 3 is the third parent.
    out_dir.ancestors().nth(3).unwrap().to_path_buf()
}
128
129fn link_vendor(os: Os, aarch64: bool, shared: bool) {
130 match (os, aarch64) {
131 (Os::Win, false) => {
132 link(
133 os, true, true, true, false, true, false, false, true, None, shared,
134 );
135 }
136 (Os::Mac, true) => {
137 link(
138 os, false, false, false, false, false, true, false, false, None, shared,
139 );
140 }
141 (Os::Linux, true) => {
142 link(
143 os, false, false, false, true, false, false, true, false, None, shared,
144 );
145 }
146 (Os::Mac, false) => {
147 link(
148 os, false, false, false, false, true, false, false, true, None, shared,
149 );
150 }
151 (Os::Linux, false) => {
152 link(
153 os, true, true, true, false, false, false, true, false, None, shared,
154 );
155 }
156 _ => panic!("Unsupported platform"),
157 }
158}
159
160fn load_vendor(os: Os, aarch64: bool, shared: bool, crt_dynamic: bool) -> Option<PathBuf> {
161 let main_dir = get_dir();
162 let out_dir = main_dir.join("ctranslate2-vendor");
163
164 let dyn_dir = out_dir.join("dyn");
165 let url = get_download_link(os, "4.6.0", aarch64, shared, crt_dynamic)?;
166 download_helper(&url, &out_dir, true)?;
167
168 watch_dir_recursively(&dyn_dir);
169
170 let files = dyn_dir
171 .read_dir()
172 .map(|v| v.into_iter().filter_map(|v| v.ok()).collect::<Vec<_>>())
173 .unwrap_or_default()
174 .iter()
175 .map(|v| v.path())
176 .filter(|p| {
177 let ext = p
178 .extension()
179 .and_then(|v| v.to_str())
180 .unwrap_or_default()
181 .to_lowercase();
182 ext == "dll" || ext == "so" || ext == "dylib"
183 })
184 .collect::<Vec<_>>();
185 println!(
186 "cargo:warning=Required dylibs are in: {}",
187 main_dir.display()
188 );
189 for file in files {
190 println!("cargo:warning=- {}", file.display());
191 let tar = main_dir.join(file.file_name().unwrap_or_default());
192 std::fs::copy(&file, &tar).unwrap();
193 println!("cargo:rerun-if-changed={}", tar.display());
195 }
196
197 println!("cargo:rustc-link-search=native={}", dyn_dir.display());
198 Some(out_dir.join("lib"))
199}
200
201pub fn main(
202 (
203 os,
204 aarch64,
205 cuda,
206 cudnn,
207 cuda_dynamic_loading,
208 mkl,
209 openblas,
210 ruy,
211 accelarate,
212 tensor_parallel,
213 dnnl,
214 openmp_comp,
215 openmp_intel,
216 msse4_1,
217 flash_attention,
218 cuda_small_binary,
219 shared,
220 vendor,
221 crt_dynamic,
222 export_vendor,
223 path,
224 ): (
225 Os,
226 bool,
227 bool,
228 bool,
229 bool,
230 bool,
231 bool,
232 bool,
233 bool,
234 bool,
235 bool,
236 bool,
237 bool,
238 bool,
239 bool,
240 bool,
241 bool,
242 bool,
243 bool,
244 bool,
245 Option<&Path>,
246 ),
247) -> PathBuf {
248 add_search_paths("LIBRARY_PATH");
249
250 println!("cargo:rerun-if-changed=build.rs");
251 println!("cargo:rerun-if-changed=src/sys");
252 println!("cargo:rerun-if-changed=include");
253 println!("cargo:rerun-if-changed=CTranslate2");
254
255 let mut found = None;
256
257 if vendor {
258 link_vendor(os, aarch64, shared);
259 found = load_vendor(os, aarch64, shared, crt_dynamic);
260 }
261 let (lib_path, include_path) = if let Some(found) = found {
262 (found.clone(), found.join("include"))
263 } else {
264 add_search_paths("CMAKE_LIBRARY_PATH");
265 link(
266 os,
267 cuda,
268 cudnn,
269 cuda_dynamic_loading,
270 openblas,
271 dnnl,
272 accelarate,
273 openmp_comp,
274 openmp_intel,
275 Some(cuda_root()).expect("CUDA_TOOLKIT_ROOT_DIR is not specified"),
276 shared,
277 );
278 let p = if let Some(path) = path {
279 path.to_path_buf()
280 } else {
281 let release =
282 std::env::var("CTRANSLATE2_RELEASE").unwrap_or_else(|_| "4.6.0".to_owned());
283 let url = format!(
284 "https://github.com/OpenNMT/CTranslate2/archive/refs/tags/v{release}.tar.gz"
285 );
286
287 let p = format!("CTranslate2-{release}");
288 let p = get_dir().join(Path::new(&p));
289 let d = &get_dir();
290 if !p.exists() {
291 download_helper(&url, d, false).unwrap();
292 }
293 for module in submodules::get_submodules_helper(d, &release) {
294 if !module.exists()
295 || read_dir(module)
296 .unwrap()
297 .into_iter()
298 .filter_map(|v| v.ok())
299 .count()
300 < 2
301 {
302 std::thread::sleep(std::time::Duration::from_millis(200));
303 }
304 }
305 if !p.exists() {
306 panic!("CTranslate2-{release} not found locally")
307 }
308 if os == Os::Win {
309 patch_cmake_runtime_flags(p.join("CMakeLists.txt"), crt_dynamic).unwrap();
310 }
311 p
312 };
313
314 (
315 build_native(
316 &p,
317 os,
318 cuda,
319 cudnn,
320 cuda_dynamic_loading,
321 aarch64,
322 mkl,
323 openblas,
324 ruy,
325 accelarate,
326 tensor_parallel,
327 msse4_1,
328 dnnl,
329 openmp_comp,
330 openmp_intel,
331 flash_attention,
332 cuda_small_binary,
333 shared,
334 ),
335 p.join("include"),
336 )
337 };
338
339 let modules = link_libraries(&lib_path);
340 let modules2 = link_dynamic_libraries(&lib_path);
341
342 if export_vendor {
343 export(&lib_path, &modules, &modules2);
344 }
345 include_path
346}