use std::fmt;
use std::pin::Pin;

use futures::Future;

use crate::language_models::options::CallOptions;

/// Model-call options accepted by chains.
///
/// Every field is optional; `None` means "not set". `ChainCallOptions::to_llm_options`
/// forwards each `Some` value to the builder method of the same name on
/// [`CallOptions`].
pub struct ChainCallOptions {
    /// Forwarded to `CallOptions::with_max_tokens`.
    pub max_tokens: Option<u32>,
    /// Forwarded to `CallOptions::with_temperature`.
    pub temperature: Option<f32>,
    /// Forwarded to `CallOptions::with_stop_words`.
    pub stop_words: Option<Vec<String>>,
    /// Async callback forwarded to `CallOptions::with_streaming_func`.
    ///
    /// Stored type-erased (boxed closure returning a boxed, pinned future) so
    /// closures with different concrete future types fit the same field.
    /// NOTE(review): presumably invoked with each streamed chunk of model
    /// output — confirm against the LLM implementations.
    pub streaming_func: Option<
        Box<dyn FnMut(String) -> Pin<Box<dyn Future<Output = Result<(), ()>> + Send>> + Send>,
    >,
    /// Forwarded to `CallOptions::with_top_k`.
    pub top_k: Option<usize>,
    /// Forwarded to `CallOptions::with_top_p`.
    pub top_p: Option<f32>,
    /// Forwarded to `CallOptions::with_seed`.
    pub seed: Option<usize>,
    /// Forwarded to `CallOptions::with_min_length`.
    pub min_length: Option<usize>,
    /// Forwarded to `CallOptions::with_max_length`.
    pub max_length: Option<usize>,
    /// Forwarded to `CallOptions::with_repetition_penalty`.
    pub repetition_penalty: Option<f32>,
}

/// Hand-written because the boxed `streaming_func` closure cannot derive
/// `Debug`; the callback is reported only as present or absent.
impl fmt::Debug for ChainCallOptions {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ChainCallOptions")
            .field("max_tokens", &self.max_tokens)
            .field("temperature", &self.temperature)
            .field("stop_words", &self.stop_words)
            .field(
                "streaming_func",
                &self.streaming_func.as_ref().map(|_| "<streaming fn>"),
            )
            .field("top_k", &self.top_k)
            .field("top_p", &self.top_p)
            .field("seed", &self.seed)
            .field("min_length", &self.min_length)
            .field("max_length", &self.max_length)
            .field("repetition_penalty", &self.repetition_penalty)
            .finish()
    }
}
21impl Default for ChainCallOptions {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl ChainCallOptions {
28 pub fn new() -> Self {
29 Self {
30 max_tokens: None,
31 temperature: None,
32 stop_words: None,
33 streaming_func: None,
34 top_k: None,
35 top_p: None,
36 seed: None,
37 min_length: None,
38 max_length: None,
39 repetition_penalty: None,
40 }
41 }
42
43 pub fn to_llm_options(options: ChainCallOptions) -> CallOptions {
44 let mut llm_option = CallOptions::new();
45 if let Some(max_tokens) = options.max_tokens {
46 llm_option = llm_option.with_max_tokens(max_tokens);
47 }
48 if let Some(temperature) = options.temperature {
49 llm_option = llm_option.with_temperature(temperature);
50 }
51 if let Some(stop_words) = options.stop_words {
52 llm_option = llm_option.with_stop_words(stop_words);
53 }
54 if let Some(top_k) = options.top_k {
55 llm_option = llm_option.with_top_k(top_k);
56 }
57 if let Some(top_p) = options.top_p {
58 llm_option = llm_option.with_top_p(top_p);
59 }
60 if let Some(seed) = options.seed {
61 llm_option = llm_option.with_seed(seed);
62 }
63 if let Some(min_length) = options.min_length {
64 llm_option = llm_option.with_min_length(min_length);
65 }
66 if let Some(max_length) = options.max_length {
67 llm_option = llm_option.with_max_length(max_length);
68 }
69 if let Some(repetition_penalty) = options.repetition_penalty {
70 llm_option = llm_option.with_repetition_penalty(repetition_penalty);
71 }
72
73 if let Some(streaming_func) = options.streaming_func {
74 llm_option = llm_option.with_streaming_func(streaming_func)
75 }
76 llm_option
77 }
78
79 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
80 self.max_tokens = Some(max_tokens);
81 self
82 }
83
84 pub fn with_temperature(mut self, temperature: f32) -> Self {
85 self.temperature = Some(temperature);
86 self
87 }
88
89 pub fn with_stop_words(mut self, stop_words: Vec<String>) -> Self {
90 self.stop_words = Some(stop_words);
91 self
92 }
93
94 pub fn with_streaming_func<F, Fut>(mut self, mut func: F) -> Self
96 where
97 F: FnMut(String) -> Fut + Send + 'static,
98 Fut: Future<Output = Result<(), ()>> + Send + 'static,
99 {
100 self.streaming_func = Some(Box::new(move |s: String| Box::pin(func(s))));
101 self
102 }
103
104 pub fn with_top_k(mut self, top_k: usize) -> Self {
105 self.top_k = Some(top_k);
106 self
107 }
108
109 pub fn with_top_p(mut self, top_p: f32) -> Self {
110 self.top_p = Some(top_p);
111 self
112 }
113
114 pub fn with_seed(mut self, seed: usize) -> Self {
115 self.seed = Some(seed);
116 self
117 }
118
119 pub fn with_min_length(mut self, min_length: usize) -> Self {
120 self.min_length = Some(min_length);
121 self
122 }
123
124 pub fn with_max_length(mut self, max_length: usize) -> Self {
125 self.max_length = Some(max_length);
126 self
127 }
128
129 pub fn with_repetition_penalty(mut self, repetition_penalty: f32) -> Self {
130 self.repetition_penalty = Some(repetition_penalty);
131 self
132 }
133}