1use crate::get_cache_dir;
2use ort::execution_providers::ExecutionProviderDispatch;
3use std::path::PathBuf;
4
/// Exposes a per-type maximum length as an associated constant.
///
/// [`InitOptionsWithLength`]'s `Default` impl copies `MAX_LENGTH` into its
/// `max_length` field. What the length measures (tokens, characters, …) is
/// defined by the implementor — NOTE(review): presumably input tokens; confirm.
pub trait HasMaxLength {
    /// The maximum length used as the default `max_length` for this model type.
    const MAX_LENGTH: usize;
}
8
9#[derive(Debug, Clone)]
10#[non_exhaustive]
11pub struct InitOptionsWithLength<M> {
12 pub model_name: M,
13 pub execution_providers: Vec<ExecutionProviderDispatch>,
14 pub cache_dir: PathBuf,
15 pub show_download_progress: bool,
16 pub max_length: usize,
17}
18
19#[derive(Debug, Clone)]
20#[non_exhaustive]
21pub struct InitOptions<M> {
22 pub model_name: M,
23 pub execution_providers: Vec<ExecutionProviderDispatch>,
24 pub cache_dir: PathBuf,
25 pub show_download_progress: bool,
26}
27
28impl<M: Default + HasMaxLength> Default for InitOptionsWithLength<M> {
29 fn default() -> Self {
30 Self {
31 model_name: M::default(),
32 execution_providers: Default::default(),
33 cache_dir: get_cache_dir().into(),
34 show_download_progress: true,
35 max_length: M::MAX_LENGTH,
36 }
37 }
38}
39
40impl<M: Default> Default for InitOptions<M> {
41 fn default() -> Self {
42 Self {
43 model_name: M::default(),
44 execution_providers: Default::default(),
45 cache_dir: get_cache_dir().into(),
46 show_download_progress: true,
47 }
48 }
49}
50
51impl<M: Default + HasMaxLength> InitOptionsWithLength<M> {
52 pub fn new(model_name: M) -> Self {
54 Self {
55 model_name,
56 ..Default::default()
57 }
58 }
59
60 pub fn with_max_length(mut self, max_length: usize) -> Self {
62 self.max_length = max_length;
63 self
64 }
65
66 pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
68 self.cache_dir = cache_dir;
69 self
70 }
71
72 pub fn with_execution_providers(
74 mut self,
75 execution_providers: Vec<ExecutionProviderDispatch>,
76 ) -> Self {
77 self.execution_providers = execution_providers;
78 self
79 }
80
81 pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self {
83 self.show_download_progress = show_download_progress;
84 self
85 }
86}
87
88impl<M: Default> InitOptions<M> {
89 pub fn new(model_name: M) -> Self {
91 Self {
92 model_name,
93 ..Default::default()
94 }
95 }
96
97 pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
99 self.cache_dir = cache_dir;
100 self
101 }
102
103 pub fn with_execution_providers(
105 mut self,
106 execution_providers: Vec<ExecutionProviderDispatch>,
107 ) -> Self {
108 self.execution_providers = execution_providers;
109 self
110 }
111
112 pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self {
114 self.show_download_progress = show_download_progress;
115 self
116 }
117}