// langchain_rust/language_models/options.rs

use futures::Future;
use std::{pin::Pin, sync::Arc};
use tokio::sync::Mutex;

use crate::schemas::{FunctionCallBehavior, FunctionDefinition};

7#[derive(Clone)]
8pub struct CallOptions {
9    pub candidate_count: Option<usize>,
10    pub max_tokens: Option<u32>,
11    pub temperature: Option<f32>,
12    pub stop_words: Option<Vec<String>>,
13    pub streaming_func: Option<
14        Arc<
15            Mutex<dyn FnMut(String) -> Pin<Box<dyn Future<Output = Result<(), ()>> + Send>> + Send>,
16        >,
17    >,
18    pub top_k: Option<usize>,
19    pub top_p: Option<f32>,
20    pub seed: Option<usize>,
21    pub min_length: Option<usize>,
22    pub max_length: Option<usize>,
23    pub n: Option<usize>,
24    pub repetition_penalty: Option<f32>,
25    pub frequency_penalty: Option<f32>,
26    pub presence_penalty: Option<f32>,
27    pub functions: Option<Vec<FunctionDefinition>>,
28    pub function_call_behavior: Option<FunctionCallBehavior>,
29    pub stream_usage: Option<bool>,
30}
31
32impl Default for CallOptions {
33    fn default() -> Self {
34        CallOptions::new()
35    }
36}
37impl CallOptions {
38    pub fn new() -> Self {
39        CallOptions {
40            candidate_count: None,
41            max_tokens: None,
42            temperature: None,
43            stop_words: None,
44            streaming_func: None,
45            top_k: None,
46            top_p: None,
47            seed: None,
48            min_length: None,
49            max_length: None,
50            n: None,
51            repetition_penalty: None,
52            frequency_penalty: None,
53            presence_penalty: None,
54            functions: None,
55            function_call_behavior: None,
56            stream_usage: None,
57        }
58    }
59
60    // Refactored "with" functions as methods of CallOptions
61    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
62        self.max_tokens = Some(max_tokens);
63        self
64    }
65
66    pub fn with_candidate_count(mut self, candidate_count: usize) -> Self {
67        self.candidate_count = Some(candidate_count);
68        self
69    }
70
71    pub fn with_temperature(mut self, temperature: f32) -> Self {
72        self.temperature = Some(temperature);
73        self
74    }
75
76    pub fn with_stop_words(mut self, stop_words: Vec<String>) -> Self {
77        self.stop_words = Some(stop_words);
78        self
79    }
80
81    //TODO:Check if this should be a &str instead of a String
82    pub fn with_streaming_func<F, Fut>(mut self, mut func: F) -> Self
83    where
84        F: FnMut(String) -> Fut + Send + 'static,
85        Fut: Future<Output = Result<(), ()>> + Send + 'static,
86    {
87        let func = Arc::new(Mutex::new(
88            move |s: String| -> Pin<Box<dyn Future<Output = Result<(), ()>> + Send>> {
89                Box::pin(func(s))
90            },
91        ));
92
93        self.streaming_func = Some(func);
94        self
95    }
96
97    pub fn with_top_k(mut self, top_k: usize) -> Self {
98        self.top_k = Some(top_k);
99        self
100    }
101
102    pub fn with_top_p(mut self, top_p: f32) -> Self {
103        self.top_p = Some(top_p);
104        self
105    }
106
107    pub fn with_seed(mut self, seed: usize) -> Self {
108        self.seed = Some(seed);
109        self
110    }
111
112    pub fn with_min_length(mut self, min_length: usize) -> Self {
113        self.min_length = Some(min_length);
114        self
115    }
116
117    pub fn with_max_length(mut self, max_length: usize) -> Self {
118        self.max_length = Some(max_length);
119        self
120    }
121
122    pub fn with_n(mut self, n: usize) -> Self {
123        self.n = Some(n);
124        self
125    }
126
127    pub fn with_repetition_penalty(mut self, repetition_penalty: f32) -> Self {
128        self.repetition_penalty = Some(repetition_penalty);
129        self
130    }
131
132    pub fn with_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
133        self.frequency_penalty = Some(frequency_penalty);
134        self
135    }
136
137    pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
138        self.presence_penalty = Some(presence_penalty);
139        self
140    }
141
142    pub fn with_functions(mut self, functions: Vec<FunctionDefinition>) -> Self {
143        self.functions = Some(functions);
144        self
145    }
146
147    pub fn with_function_call_behavior(mut self, behavior: FunctionCallBehavior) -> Self {
148        self.function_call_behavior = Some(behavior);
149        self
150    }
151
152    pub fn with_stream_usage(mut self, stream_usage: bool) -> Self {
153        self.stream_usage = Some(stream_usage);
154        self
155    }
156
157    pub fn merge_options(&mut self, incoming_options: CallOptions) {
158        // For simple scalar types wrapped in Option, prefer incoming option if it is Some
159        self.candidate_count = incoming_options.candidate_count.or(self.candidate_count);
160        self.max_tokens = incoming_options.max_tokens.or(self.max_tokens);
161        self.temperature = incoming_options.temperature.or(self.temperature);
162        self.top_k = incoming_options.top_k.or(self.top_k);
163        self.top_p = incoming_options.top_p.or(self.top_p);
164        self.seed = incoming_options.seed.or(self.seed);
165        self.min_length = incoming_options.min_length.or(self.min_length);
166        self.max_length = incoming_options.max_length.or(self.max_length);
167        self.n = incoming_options.n.or(self.n);
168        self.repetition_penalty = incoming_options
169            .repetition_penalty
170            .or(self.repetition_penalty);
171        self.frequency_penalty = incoming_options
172            .frequency_penalty
173            .or(self.frequency_penalty);
174        self.presence_penalty = incoming_options.presence_penalty.or(self.presence_penalty);
175        self.function_call_behavior = incoming_options
176            .function_call_behavior
177            .or(self.function_call_behavior.clone());
178        self.stream_usage = incoming_options.stream_usage.or(self.stream_usage);
179
180        // For `Vec<String>`, merge if both are Some; prefer incoming if only incoming is Some
181        if let Some(mut new_stop_words) = incoming_options.stop_words {
182            if let Some(existing_stop_words) = &mut self.stop_words {
183                existing_stop_words.append(&mut new_stop_words);
184            } else {
185                self.stop_words = Some(new_stop_words);
186            }
187        }
188
189        // For `Vec<FunctionDefinition>`, similar logic to `Vec<String>`
190        if let Some(mut incoming_functions) = incoming_options.functions {
191            if let Some(existing_functions) = &mut self.functions {
192                existing_functions.append(&mut incoming_functions);
193            } else {
194                self.functions = Some(incoming_functions);
195            }
196        }
197
198        // `streaming_func` requires a judgment call on how you want to handle merging.
199        // Here, the incoming option simply replaces the existing one if it's Some.
200        self.streaming_func = incoming_options
201            .streaming_func
202            .or_else(|| self.streaming_func.clone());
203    }
204}