langchain_rust/language_models/
options.rs1use futures::Future;
2use std::{pin::Pin, sync::Arc};
3use tokio::sync::Mutex;
4
5use crate::schemas::{FunctionCallBehavior, FunctionDefinition};
6
7#[derive(Clone)]
8pub struct CallOptions {
9 pub candidate_count: Option<usize>,
10 pub max_tokens: Option<u32>,
11 pub temperature: Option<f32>,
12 pub stop_words: Option<Vec<String>>,
13 pub streaming_func: Option<
14 Arc<
15 Mutex<dyn FnMut(String) -> Pin<Box<dyn Future<Output = Result<(), ()>> + Send>> + Send>,
16 >,
17 >,
18 pub top_k: Option<usize>,
19 pub top_p: Option<f32>,
20 pub seed: Option<usize>,
21 pub min_length: Option<usize>,
22 pub max_length: Option<usize>,
23 pub n: Option<usize>,
24 pub repetition_penalty: Option<f32>,
25 pub frequency_penalty: Option<f32>,
26 pub presence_penalty: Option<f32>,
27 pub functions: Option<Vec<FunctionDefinition>>,
28 pub function_call_behavior: Option<FunctionCallBehavior>,
29 pub stream_usage: Option<bool>,
30}
31
32impl Default for CallOptions {
33 fn default() -> Self {
34 CallOptions::new()
35 }
36}
37impl CallOptions {
38 pub fn new() -> Self {
39 CallOptions {
40 candidate_count: None,
41 max_tokens: None,
42 temperature: None,
43 stop_words: None,
44 streaming_func: None,
45 top_k: None,
46 top_p: None,
47 seed: None,
48 min_length: None,
49 max_length: None,
50 n: None,
51 repetition_penalty: None,
52 frequency_penalty: None,
53 presence_penalty: None,
54 functions: None,
55 function_call_behavior: None,
56 stream_usage: None,
57 }
58 }
59
60 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
62 self.max_tokens = Some(max_tokens);
63 self
64 }
65
66 pub fn with_candidate_count(mut self, candidate_count: usize) -> Self {
67 self.candidate_count = Some(candidate_count);
68 self
69 }
70
71 pub fn with_temperature(mut self, temperature: f32) -> Self {
72 self.temperature = Some(temperature);
73 self
74 }
75
76 pub fn with_stop_words(mut self, stop_words: Vec<String>) -> Self {
77 self.stop_words = Some(stop_words);
78 self
79 }
80
81 pub fn with_streaming_func<F, Fut>(mut self, mut func: F) -> Self
83 where
84 F: FnMut(String) -> Fut + Send + 'static,
85 Fut: Future<Output = Result<(), ()>> + Send + 'static,
86 {
87 let func = Arc::new(Mutex::new(
88 move |s: String| -> Pin<Box<dyn Future<Output = Result<(), ()>> + Send>> {
89 Box::pin(func(s))
90 },
91 ));
92
93 self.streaming_func = Some(func);
94 self
95 }
96
97 pub fn with_top_k(mut self, top_k: usize) -> Self {
98 self.top_k = Some(top_k);
99 self
100 }
101
102 pub fn with_top_p(mut self, top_p: f32) -> Self {
103 self.top_p = Some(top_p);
104 self
105 }
106
107 pub fn with_seed(mut self, seed: usize) -> Self {
108 self.seed = Some(seed);
109 self
110 }
111
112 pub fn with_min_length(mut self, min_length: usize) -> Self {
113 self.min_length = Some(min_length);
114 self
115 }
116
117 pub fn with_max_length(mut self, max_length: usize) -> Self {
118 self.max_length = Some(max_length);
119 self
120 }
121
122 pub fn with_n(mut self, n: usize) -> Self {
123 self.n = Some(n);
124 self
125 }
126
127 pub fn with_repetition_penalty(mut self, repetition_penalty: f32) -> Self {
128 self.repetition_penalty = Some(repetition_penalty);
129 self
130 }
131
132 pub fn with_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
133 self.frequency_penalty = Some(frequency_penalty);
134 self
135 }
136
137 pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
138 self.presence_penalty = Some(presence_penalty);
139 self
140 }
141
142 pub fn with_functions(mut self, functions: Vec<FunctionDefinition>) -> Self {
143 self.functions = Some(functions);
144 self
145 }
146
147 pub fn with_function_call_behavior(mut self, behavior: FunctionCallBehavior) -> Self {
148 self.function_call_behavior = Some(behavior);
149 self
150 }
151
152 pub fn with_stream_usage(mut self, stream_usage: bool) -> Self {
153 self.stream_usage = Some(stream_usage);
154 self
155 }
156
157 pub fn merge_options(&mut self, incoming_options: CallOptions) {
158 self.candidate_count = incoming_options.candidate_count.or(self.candidate_count);
160 self.max_tokens = incoming_options.max_tokens.or(self.max_tokens);
161 self.temperature = incoming_options.temperature.or(self.temperature);
162 self.top_k = incoming_options.top_k.or(self.top_k);
163 self.top_p = incoming_options.top_p.or(self.top_p);
164 self.seed = incoming_options.seed.or(self.seed);
165 self.min_length = incoming_options.min_length.or(self.min_length);
166 self.max_length = incoming_options.max_length.or(self.max_length);
167 self.n = incoming_options.n.or(self.n);
168 self.repetition_penalty = incoming_options
169 .repetition_penalty
170 .or(self.repetition_penalty);
171 self.frequency_penalty = incoming_options
172 .frequency_penalty
173 .or(self.frequency_penalty);
174 self.presence_penalty = incoming_options.presence_penalty.or(self.presence_penalty);
175 self.function_call_behavior = incoming_options
176 .function_call_behavior
177 .or(self.function_call_behavior.clone());
178 self.stream_usage = incoming_options.stream_usage.or(self.stream_usage);
179
180 if let Some(mut new_stop_words) = incoming_options.stop_words {
182 if let Some(existing_stop_words) = &mut self.stop_words {
183 existing_stop_words.append(&mut new_stop_words);
184 } else {
185 self.stop_words = Some(new_stop_words);
186 }
187 }
188
189 if let Some(mut incoming_functions) = incoming_options.functions {
191 if let Some(existing_functions) = &mut self.functions {
192 existing_functions.append(&mut incoming_functions);
193 } else {
194 self.functions = Some(incoming_functions);
195 }
196 }
197
198 self.streaming_func = incoming_options
201 .streaming_func
202 .or_else(|| self.streaming_func.clone());
203 }
204}