Skip to main content

fastembed/
init.rs

1use crate::get_cache_dir;
2use ort::execution_providers::ExecutionProviderDispatch;
3use std::path::PathBuf;
4
/// Associates a model type with a maximum length fixed at compile time.
///
/// [`InitOptionsWithLength`] reads [`HasMaxLength::MAX_LENGTH`] to fill in the
/// default value of its `max_length` field.
pub trait HasMaxLength {
    /// Default maximum length for this model type.
    const MAX_LENGTH: usize;
}
8
9#[derive(Debug, Clone)]
10#[non_exhaustive]
11pub struct InitOptionsWithLength<M> {
12    pub model_name: M,
13    pub execution_providers: Vec<ExecutionProviderDispatch>,
14    pub cache_dir: PathBuf,
15    pub show_download_progress: bool,
16    pub max_length: usize,
17}
18
19#[derive(Debug, Clone)]
20#[non_exhaustive]
21pub struct InitOptions<M> {
22    pub model_name: M,
23    pub execution_providers: Vec<ExecutionProviderDispatch>,
24    pub cache_dir: PathBuf,
25    pub show_download_progress: bool,
26}
27
28impl<M: Default + HasMaxLength> Default for InitOptionsWithLength<M> {
29    fn default() -> Self {
30        Self {
31            model_name: M::default(),
32            execution_providers: Default::default(),
33            cache_dir: get_cache_dir().into(),
34            show_download_progress: true,
35            max_length: M::MAX_LENGTH,
36        }
37    }
38}
39
40impl<M: Default> Default for InitOptions<M> {
41    fn default() -> Self {
42        Self {
43            model_name: M::default(),
44            execution_providers: Default::default(),
45            cache_dir: get_cache_dir().into(),
46            show_download_progress: true,
47        }
48    }
49}
50
51impl<M: Default + HasMaxLength> InitOptionsWithLength<M> {
52    /// Create a new InitOptionsWithLength with the given model name
53    pub fn new(model_name: M) -> Self {
54        Self {
55            model_name,
56            ..Default::default()
57        }
58    }
59
60    /// Set the maximum length
61    pub fn with_max_length(mut self, max_length: usize) -> Self {
62        self.max_length = max_length;
63        self
64    }
65
66    /// Set the cache directory for the model file
67    pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
68        self.cache_dir = cache_dir;
69        self
70    }
71
72    /// Set the execution providers for the model
73    pub fn with_execution_providers(
74        mut self,
75        execution_providers: Vec<ExecutionProviderDispatch>,
76    ) -> Self {
77        self.execution_providers = execution_providers;
78        self
79    }
80
81    /// Set whether to show download progress
82    pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self {
83        self.show_download_progress = show_download_progress;
84        self
85    }
86}
87
88impl<M: Default> InitOptions<M> {
89    /// Create a new InitOptions with the given model name
90    pub fn new(model_name: M) -> Self {
91        Self {
92            model_name,
93            ..Default::default()
94        }
95    }
96
97    /// Set the cache directory for the model file
98    pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
99        self.cache_dir = cache_dir;
100        self
101    }
102
103    /// Set the execution providers for the model
104    pub fn with_execution_providers(
105        mut self,
106        execution_providers: Vec<ExecutionProviderDispatch>,
107    ) -> Self {
108        self.execution_providers = execution_providers;
109        self
110    }
111
112    /// Set whether to show download progress
113    pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self {
114        self.show_download_progress = show_download_progress;
115        self
116    }
117}