1use std::{
2 env,
3 fs::read_dir,
4 path::{Path, PathBuf},
5};
6
7use crate::{
8 Os,
9 dnnl::build_dnnl,
10 download::download_helper,
11 export,
12 file_changes::watch_dir_recursively,
13 link_dynamic_libraries, link_libraries,
14 native::{build_native, cuda_root},
15 submodules,
16 windows_crt_patch::patch_cmake_runtime_flags,
17};
18
19pub fn link(
20 os: Os,
21 cuda: bool,
22 cudnn: bool,
23 cuda_dynamic_loading: bool,
24 openblas: bool,
25 dnnl: bool,
26 accelarate: bool,
27 openmp_comp: bool,
28 openmp_intel: bool,
29 cuda_root: Option<PathBuf>,
30 shared: bool,
31) {
32 if cuda && !shared {
33 if let Some(cuda) = cuda_root {
34 println!("cargo:rustc-link-search={}", cuda.join("lib").display());
35 println!("cargo:rustc-link-search={}", cuda.join("lib64").display());
36 println!("cargo:rustc-link-search={}", cuda.join("lib/x64").display());
37 }
38
39 println!("cargo:rustc-link-lib=static=cudart_static");
40 if cudnn {
41 println!("cargo:rustc-link-lib=cudnn");
42 }
43 if !cuda_dynamic_loading {
44 if os == Os::Win {
45 println!("cargo:rustc-link-lib=static=cublas");
46 println!("cargo:rustc-link-lib=static=cublasLt");
47 } else {
48 println!("cargo:rustc-link-lib=static=cublas_static");
49 println!("cargo:rustc-link-lib=static=cublasLt_static");
50 println!("cargo:rustc-link-lib=static=culibos");
51 }
52 }
53 }
54
55 if openblas && !shared {
56 println!("cargo:rustc-link-lib=static=openblas");
57 }
58 if accelarate {
59 println!("cargo:rustc-link-lib=framework=Accelerate");
60 }
61 if dnnl {
62 build_dnnl(!shared);
63 }
64 if openmp_comp && !shared {
65 println!("cargo:rustc-link-lib=gomp");
66 } else if openmp_intel && !shared {
67 if os == Os::Win {
68 println!("cargo:rustc-link-lib=dylib=libiomp5md");
69 } else {
70 println!("cargo:rustc-link-lib=iomp5");
71 }
72 }
73}
74
/// Splits a search-path list (the format of `LIBRARY_PATH`-style variables) into its
/// non-empty entries.
///
/// Uses the host platform's rules via `env::split_paths` (`:` on Unix; `;` on Windows,
/// where double-quoted entries are also unquoted). Empty entries, e.g. from `a::b` or a
/// trailing separator, are dropped.
fn search_paths(value: &std::ffi::OsStr) -> Vec<PathBuf> {
    env::split_paths(value)
        .filter(|path| !path.as_os_str().is_empty())
        .collect()
}

/// Forwards every entry of the environment variable `key` to cargo as a linker search
/// path.
///
/// Always registers `key` with `cargo:rerun-if-env-changed`, so editing the variable
/// reruns the build script. If the variable is unset, nothing else is emitted. Values that
/// are not valid UTF-8 are still honoured; `env::var` would silently ignore them.
fn add_search_paths(key: &str) {
    println!("cargo:rerun-if-env-changed={}", key);
    if let Some(value) = env::var_os(key) {
        for dir in search_paths(&value) {
            println!("cargo:rustc-link-search={}", dir.display());
        }
    }
}
92
93fn get_download_link(
94 os: Os,
95 version: &str,
96 aarch64: bool,
97 shared: bool,
98 crt_dyn: bool,
99) -> Option<String> {
100 Some(format!(
101 "https://github.com/frederik-uni/ctranslate2-src/releases/download/v{version}/ctranslate2-{}{}-{}-{}.tar.gz",
102 if shared { "shared" } else { "static" },
103 if crt_dyn && os == Os::Win { "-crt" } else { "" },
104 match os {
105 Os::Win => "windows",
106 Os::Mac => "macos",
107 Os::Linux => "linux",
108 Os::Unknown => return None,
109 },
110 match aarch64 {
111 true => "arm64",
112 false => "x86_64",
113 }
114 ))
115}
116
/// Returns the directory three levels above cargo's `OUT_DIR`.
///
/// For a build script, `OUT_DIR` is normally `<target-dir>/[<triple>/]<profile>/build/<pkg>-<hash>/out`,
/// so this is usually the profile output directory. Presumably it is chosen because
/// it sits next to the final binaries. (Per the Cargo build-script docs; confirm for
/// unusual target-dir layouts.)
///
/// # Panics
/// If `OUT_DIR` is unset (not running under cargo) or has fewer than three ancestors.
fn get_dir() -> PathBuf {
    // `var_os` rather than `var`: a target directory whose path is not valid UTF-8 must
    // not make the build script panic.
    let out_dir = PathBuf::from(
        env::var_os("OUT_DIR").expect("OUT_DIR must be set; this runs inside a cargo build script"),
    );
    out_dir
        .ancestors()
        .nth(3)
        .expect("OUT_DIR should be nested at least three directories deep")
        .to_path_buf()
}
128
129fn link_vendor(os: Os, aarch64: bool, shared: bool) {
130 match (os, aarch64) {
131 (Os::Win, false) => {
132 link(
133 os, true, true, true, false, true, false, false, true, None, shared,
134 );
135 }
136 (Os::Mac, true) => {
137 link(
138 os, false, false, false, false, false, true, false, false, None, shared,
139 );
140 }
141 (Os::Linux, true) => {
142 link(
143 os, false, false, false, true, false, false, true, false, None, shared,
144 );
145 }
146 (Os::Mac, false) => {
147 link(
148 os, false, false, false, false, true, false, false, true, None, shared,
149 );
150 }
151 (Os::Linux, false) => {
152 link(
153 os, true, true, true, false, false, false, true, false, None, shared,
154 );
155 }
156 _ => panic!("Unsupported platform"),
157 }
158}
159
160fn load_vendor(os: Os, aarch64: bool, shared: bool, crt_dynamic: bool) -> Option<PathBuf> {
161 let main_dir = get_dir();
162 let out_dir = main_dir.join("ctranslate2-vendor");
163
164 let dyn_dir = out_dir.join("dyn");
165 let url = get_download_link(os, "4.6.0", aarch64, shared, crt_dynamic)?;
166 download_helper(&url, &out_dir, true)?;
167
168 watch_dir_recursively(&dyn_dir);
169
170 let files = dyn_dir
171 .read_dir()
172 .map(|v| v.into_iter().filter_map(|v| v.ok()).collect::<Vec<_>>())
173 .unwrap_or_default()
174 .iter()
175 .map(|v| v.path())
176 .filter(|p| {
177 let ext = p
178 .extension()
179 .and_then(|v| v.to_str())
180 .unwrap_or_default()
181 .to_lowercase();
182 ext == "dll" || ext == "so" || ext == "dylib"
183 })
184 .collect::<Vec<_>>();
185 println!(
186 "cargo:warning=Required dylibs are in: {}",
187 main_dir.display()
188 );
189 for file in files {
190 println!("cargo:warning=- {}", file.display());
191 let tar = main_dir.join(file.file_name().unwrap_or_default());
192 std::fs::copy(&file, &tar).unwrap();
193 println!("cargo:rerun-if-changed={}", tar.display());
195 }
196
197 println!("cargo:rustc-link-search=native={}", dyn_dir.display());
198 Some(out_dir.join("lib"))
199}
200
201pub fn main(
202 (
203 os,
204 aarch64,
205 cuda,
206 cudnn,
207 cuda_dynamic_loading,
208 mkl,
209 openblas,
210 ruy,
211 accelarate,
212 tensor_parallel,
213 dnnl,
214 openmp_comp,
215 openmp_intel,
216 msse4_1,
217 flash_attention,
218 cuda_small_binary,
219 shared,
220 vendor,
221 crt_dynamic,
222 export_vendor,
223 ): (
224 Os,
225 bool,
226 bool,
227 bool,
228 bool,
229 bool,
230 bool,
231 bool,
232 bool,
233 bool,
234 bool,
235 bool,
236 bool,
237 bool,
238 bool,
239 bool,
240 bool,
241 bool,
242 bool,
243 bool,
244 ),
245) -> PathBuf {
246 add_search_paths("LIBRARY_PATH");
247
248 println!("cargo:rerun-if-changed=build.rs");
249 println!("cargo:rerun-if-changed=src/sys");
250 println!("cargo:rerun-if-changed=include");
251 println!("cargo:rerun-if-changed=CTranslate2");
252
253 let mut found = None;
254
255 if vendor {
256 link_vendor(os, aarch64, shared);
257 found = load_vendor(os, aarch64, shared, crt_dynamic);
258 }
259 let (lib_path, include_path) = if let Some(found) = found {
260 (found.clone(), found.join("include"))
261 } else {
262 add_search_paths("CMAKE_LIBRARY_PATH");
263 link(
264 os,
265 cuda,
266 cudnn,
267 cuda_dynamic_loading,
268 openblas,
269 dnnl,
270 accelarate,
271 openmp_comp,
272 openmp_intel,
273 Some(cuda_root()).expect("CUDA_TOOLKIT_ROOT_DIR is not specified"),
274 shared,
275 );
276 let release = std::env::var("CTRANSLATE2_RELEASE").unwrap_or_else(|_| "4.6.0".to_owned());
277 let url =
278 format!("https://github.com/OpenNMT/CTranslate2/archive/refs/tags/v{release}.tar.gz");
279
280 let p = format!("CTranslate2-{release}");
281 let p = get_dir().join(Path::new(&p));
282 let d = &get_dir();
283 if !p.exists() {
284 download_helper(&url, d, false).unwrap();
285 }
286 for module in submodules::get_submodules_helper(d, &release) {
287 if !module.exists()
288 || read_dir(module)
289 .unwrap()
290 .into_iter()
291 .filter_map(|v| v.ok())
292 .count()
293 < 2
294 {
295 std::thread::sleep(std::time::Duration::from_millis(200));
296 }
297 }
298 if !p.exists() {
299 panic!("CTranslate2-{release} not found locally")
300 }
301 if os == Os::Win {
302 patch_cmake_runtime_flags(p.join("CMakeLists.txt"), crt_dynamic).unwrap();
303 }
304 (
305 build_native(
306 &p,
307 os,
308 cuda,
309 cudnn,
310 cuda_dynamic_loading,
311 aarch64,
312 mkl,
313 openblas,
314 ruy,
315 accelarate,
316 tensor_parallel,
317 msse4_1,
318 dnnl,
319 openmp_comp,
320 openmp_intel,
321 flash_attention,
322 cuda_small_binary,
323 shared,
324 ),
325 p.join("include"),
326 )
327 };
328
329 let modules = link_libraries(&lib_path);
330 let modules2 = link_dynamic_libraries(&lib_path);
331
332 if export_vendor {
333 export(&lib_path, &modules, &modules2);
334 }
335 include_path
336}