Skip to main content

ctranslate2_src_build_support/
a.rs

1use std::{
2    env,
3    fs::read_dir,
4    path::{Path, PathBuf},
5};
6
7use crate::{
8    Os,
9    dnnl::build_dnnl,
10    download::download_helper,
11    export,
12    file_changes::watch_dir_recursively,
13    link_dynamic_libraries, link_libraries,
14    native::{build_native, cuda_root},
15    submodules,
16    windows_crt_patch::patch_cmake_runtime_flags,
17};
18
19pub fn link(
20    os: Os,
21    cuda: bool,
22    cudnn: bool,
23    cuda_dynamic_loading: bool,
24    openblas: bool,
25    dnnl: bool,
26    accelarate: bool,
27    openmp_comp: bool,
28    openmp_intel: bool,
29    cuda_root: Option<PathBuf>,
30    shared: bool,
31) {
32    if cuda && !shared {
33        if let Some(cuda) = cuda_root {
34            println!("cargo:rustc-link-search={}", cuda.join("lib").display());
35            println!("cargo:rustc-link-search={}", cuda.join("lib64").display());
36            println!("cargo:rustc-link-search={}", cuda.join("lib/x64").display());
37        }
38
39        println!("cargo:rustc-link-lib=static=cudart_static");
40        if cudnn {
41            println!("cargo:rustc-link-lib=cudnn");
42        }
43        if !cuda_dynamic_loading {
44            if os == Os::Win {
45                println!("cargo:rustc-link-lib=static=cublas");
46                println!("cargo:rustc-link-lib=static=cublasLt");
47            } else {
48                println!("cargo:rustc-link-lib=static=cublas_static");
49                println!("cargo:rustc-link-lib=static=cublasLt_static");
50                println!("cargo:rustc-link-lib=static=culibos");
51            }
52        }
53    }
54
55    if openblas && !shared {
56        println!("cargo:rustc-link-lib=static=openblas");
57    }
58    if accelarate {
59        println!("cargo:rustc-link-lib=framework=Accelerate");
60    }
61    if dnnl {
62        // build_dnnl(!shared);
63    }
64    if openmp_comp && !shared {
65        println!("cargo:rustc-link-lib=gomp");
66    } else if openmp_intel && !shared {
67        if os == Os::Win {
68            println!("cargo:rustc-link-lib=dylib=libiomp5md");
69        } else {
70            println!("cargo:rustc-link-lib=iomp5");
71        }
72    }
73}
74
/// Separator between entries of list-valued environment variables such as
/// `LIBRARY_PATH`: `;` when compiled for Windows, `:` everywhere else.
const PATH_SEPARATOR: char = if cfg!(target_os = "windows") { ';' } else { ':' };
80
81fn add_search_paths(key: &str) {
82    println!("cargo:rerun-if-env-changed={}", key);
83    if let Ok(library_path) = env::var(key) {
84        library_path
85            .split(PATH_SEPARATOR)
86            .filter(|v| !v.is_empty())
87            .for_each(|v| {
88                println!("cargo:rustc-link-search={}", v);
89            });
90    }
91}
92
93fn get_download_link(
94    os: Os,
95    version: &str,
96    aarch64: bool,
97    shared: bool,
98    crt_dyn: bool,
99) -> Option<String> {
100    Some(format!(
101        "https://github.com/frederik-uni/ctranslate2-src/releases/download/v{version}/ctranslate2-{}{}-{}-{}.tar.gz",
102        if shared { "shared" } else { "static" },
103        if crt_dyn && os == Os::Win { "-crt" } else { "" },
104        match os {
105            Os::Win => "windows",
106            Os::Mac => "macos",
107            Os::Linux => "linux",
108            Os::Unknown => return None,
109        },
110        match aarch64 {
111            true => "arm64",
112            false => "x86_64",
113        }
114    ))
115}
116
/// Returns the directory three levels above `OUT_DIR`.
///
/// With Cargo's usual `target/<profile>/build/<pkg>-<hash>/out` layout this is typically
/// `target/<profile>`. Panics if `OUT_DIR` is unset or has fewer than three ancestors.
fn get_dir() -> PathBuf {
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    // `ancestors()` yields the path itself first, so index 3 is the third parent.
    out_dir.ancestors().nth(3).unwrap().to_path_buf()
}
128
129fn link_vendor(os: Os, aarch64: bool, shared: bool) {
130    match (os, aarch64) {
131        (Os::Win, false) => {
132            link(
133                os, true, true, true, false, true, false, false, true, None, shared,
134            );
135        }
136        (Os::Mac, true) => {
137            link(
138                os, false, false, false, false, false, true, false, false, None, shared,
139            );
140        }
141        (Os::Linux, true) => {
142            link(
143                os, false, false, false, true, false, false, true, false, None, shared,
144            );
145        }
146        (Os::Mac, false) => {
147            link(
148                os, false, false, false, false, true, false, false, true, None, shared,
149            );
150        }
151        (Os::Linux, false) => {
152            link(
153                os, true, true, true, false, false, false, true, false, None, shared,
154            );
155        }
156        _ => panic!("Unsupported platform"),
157    }
158}
159
160fn load_vendor(os: Os, aarch64: bool, shared: bool, crt_dynamic: bool) -> Option<PathBuf> {
161    let main_dir = get_dir();
162    let out_dir = main_dir.join("ctranslate2-vendor");
163
164    let dyn_dir = out_dir.join("dyn");
165    let url = get_download_link(os, "4.6.0", aarch64, shared, crt_dynamic)?;
166    download_helper(&url, &out_dir, true)?;
167
168    watch_dir_recursively(&dyn_dir);
169
170    let files = dyn_dir
171        .read_dir()
172        .map(|v| v.into_iter().filter_map(|v| v.ok()).collect::<Vec<_>>())
173        .unwrap_or_default()
174        .iter()
175        .map(|v| v.path())
176        .filter(|p| {
177            let ext = p
178                .extension()
179                .and_then(|v| v.to_str())
180                .unwrap_or_default()
181                .to_lowercase();
182            ext == "dll" || ext == "so" || ext == "dylib"
183        })
184        .collect::<Vec<_>>();
185    println!(
186        "cargo:warning=Required dylibs are in: {}",
187        main_dir.display()
188    );
189    for file in files {
190        println!("cargo:warning=- {}", file.display());
191        let tar = main_dir.join(file.file_name().unwrap_or_default());
192        std::fs::copy(&file, &tar).unwrap();
193        // Github actions has sometimes some issues with finding files. I hope that fixes it
194        println!("cargo:rerun-if-changed={}", tar.display());
195    }
196
197    println!("cargo:rustc-link-search=native={}", dyn_dir.display());
198    Some(out_dir.join("lib"))
199}
200
201pub fn main(
202    (
203        os,
204        aarch64,
205        cuda,
206        cudnn,
207        cuda_dynamic_loading,
208        mkl,
209        openblas,
210        ruy,
211        accelarate,
212        tensor_parallel,
213        dnnl,
214        openmp_comp,
215        openmp_intel,
216        msse4_1,
217        flash_attention,
218        cuda_small_binary,
219        shared,
220        vendor,
221        crt_dynamic,
222        export_vendor,
223        path,
224    ): (
225        Os,
226        bool,
227        bool,
228        bool,
229        bool,
230        bool,
231        bool,
232        bool,
233        bool,
234        bool,
235        bool,
236        bool,
237        bool,
238        bool,
239        bool,
240        bool,
241        bool,
242        bool,
243        bool,
244        bool,
245        Option<&Path>,
246    ),
247) -> PathBuf {
248    add_search_paths("LIBRARY_PATH");
249
250    println!("cargo:rerun-if-changed=build.rs");
251    println!("cargo:rerun-if-changed=src/sys");
252    println!("cargo:rerun-if-changed=include");
253    println!("cargo:rerun-if-changed=CTranslate2");
254
255    let mut found = None;
256
257    if vendor {
258        link_vendor(os, aarch64, shared);
259        found = load_vendor(os, aarch64, shared, crt_dynamic);
260    }
261    let (lib_path, include_path) = if let Some(found) = found {
262        (found.clone(), found.join("include"))
263    } else {
264        add_search_paths("CMAKE_LIBRARY_PATH");
265        link(
266            os,
267            cuda,
268            cudnn,
269            cuda_dynamic_loading,
270            openblas,
271            dnnl,
272            accelarate,
273            openmp_comp,
274            openmp_intel,
275            Some(cuda_root()).expect("CUDA_TOOLKIT_ROOT_DIR is not specified"),
276            shared,
277        );
278        let p = if let Some(path) = path {
279            path.to_path_buf()
280        } else {
281            let release =
282                std::env::var("CTRANSLATE2_RELEASE").unwrap_or_else(|_| "4.6.0".to_owned());
283            let url = format!(
284                "https://github.com/OpenNMT/CTranslate2/archive/refs/tags/v{release}.tar.gz"
285            );
286
287            let p = format!("CTranslate2-{release}");
288            let p = get_dir().join(Path::new(&p));
289            let d = &get_dir();
290            if !p.exists() {
291                download_helper(&url, d, false).unwrap();
292            }
293            for module in submodules::get_submodules_helper(d, &release) {
294                if !module.exists()
295                    || read_dir(module)
296                        .unwrap()
297                        .into_iter()
298                        .filter_map(|v| v.ok())
299                        .count()
300                        < 2
301                {
302                    std::thread::sleep(std::time::Duration::from_millis(200));
303                }
304            }
305            if !p.exists() {
306                panic!("CTranslate2-{release} not found locally")
307            }
308            if os == Os::Win {
309                patch_cmake_runtime_flags(p.join("CMakeLists.txt"), crt_dynamic).unwrap();
310            }
311            p
312        };
313
314        (
315            build_native(
316                &p,
317                os,
318                cuda,
319                cudnn,
320                cuda_dynamic_loading,
321                aarch64,
322                mkl,
323                openblas,
324                ruy,
325                accelarate,
326                tensor_parallel,
327                msse4_1,
328                dnnl,
329                openmp_comp,
330                openmp_intel,
331                flash_attention,
332                cuda_small_binary,
333                shared,
334            ),
335            p.join("include"),
336        )
337    };
338
339    let modules = link_libraries(&lib_path);
340    let modules2 = link_dynamic_libraries(&lib_path);
341
342    if export_vendor {
343        export(&lib_path, &modules, &modules2);
344    }
345    include_path
346}