langchain_rust/chain/options.rs

1use futures::Future;
2use std::pin::Pin;
3
4use crate::language_models::options::CallOptions;
5
/// Per-call options for a chain.
///
/// Every field is optional. `ChainCallOptions::to_llm_options` forwards only the
/// fields that are `Some` to the language model's `CallOptions`.
pub struct ChainCallOptions {
    /// Forwarded via `CallOptions::with_max_tokens` when set.
    pub max_tokens: Option<u32>,
    /// Forwarded via `CallOptions::with_temperature` when set.
    pub temperature: Option<f32>,
    /// Forwarded via `CallOptions::with_stop_words` when set.
    pub stop_words: Option<Vec<String>>,
    /// Type-erased async callback that takes an owned `String`. Each call returns a
    /// boxed, pinned, `Send` future resolving to `Result<(), ()>`. Forwarded via
    /// `CallOptions::with_streaming_func` when set.
    pub streaming_func: Option<
        Box<dyn FnMut(String) -> Pin<Box<dyn Future<Output = Result<(), ()>> + Send>> + Send>,
    >,
    /// Forwarded via `CallOptions::with_top_k` when set.
    pub top_k: Option<usize>,
    /// Forwarded via `CallOptions::with_top_p` when set.
    pub top_p: Option<f32>,
    /// Forwarded via `CallOptions::with_seed` when set.
    pub seed: Option<usize>,
    /// Forwarded via `CallOptions::with_min_length` when set.
    pub min_length: Option<usize>,
    /// Forwarded via `CallOptions::with_max_length` when set.
    pub max_length: Option<usize>,
    /// Forwarded via `CallOptions::with_repetition_penalty` when set.
    pub repetition_penalty: Option<f32>,
}

// `Debug` cannot be derived because `Box<dyn FnMut ...>` has no `Debug` impl.
// Implement it by hand so the public options type can be logged and inspected.
impl std::fmt::Debug for ChainCallOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChainCallOptions")
            .field("max_tokens", &self.max_tokens)
            .field("temperature", &self.temperature)
            .field("stop_words", &self.stop_words)
            // The closure itself is opaque; report only whether one is set.
            .field(
                "streaming_func",
                &self.streaming_func.as_ref().map(|_| "<streaming fn>"),
            )
            .field("top_k", &self.top_k)
            .field("top_p", &self.top_p)
            .field("seed", &self.seed)
            .field("min_length", &self.min_length)
            .field("max_length", &self.max_length)
            .field("repetition_penalty", &self.repetition_penalty)
            .finish()
    }
}
20
21impl Default for ChainCallOptions {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl ChainCallOptions {
28    pub fn new() -> Self {
29        Self {
30            max_tokens: None,
31            temperature: None,
32            stop_words: None,
33            streaming_func: None,
34            top_k: None,
35            top_p: None,
36            seed: None,
37            min_length: None,
38            max_length: None,
39            repetition_penalty: None,
40        }
41    }
42
43    pub fn to_llm_options(options: ChainCallOptions) -> CallOptions {
44        let mut llm_option = CallOptions::new();
45        if let Some(max_tokens) = options.max_tokens {
46            llm_option = llm_option.with_max_tokens(max_tokens);
47        }
48        if let Some(temperature) = options.temperature {
49            llm_option = llm_option.with_temperature(temperature);
50        }
51        if let Some(stop_words) = options.stop_words {
52            llm_option = llm_option.with_stop_words(stop_words);
53        }
54        if let Some(top_k) = options.top_k {
55            llm_option = llm_option.with_top_k(top_k);
56        }
57        if let Some(top_p) = options.top_p {
58            llm_option = llm_option.with_top_p(top_p);
59        }
60        if let Some(seed) = options.seed {
61            llm_option = llm_option.with_seed(seed);
62        }
63        if let Some(min_length) = options.min_length {
64            llm_option = llm_option.with_min_length(min_length);
65        }
66        if let Some(max_length) = options.max_length {
67            llm_option = llm_option.with_max_length(max_length);
68        }
69        if let Some(repetition_penalty) = options.repetition_penalty {
70            llm_option = llm_option.with_repetition_penalty(repetition_penalty);
71        }
72
73        if let Some(streaming_func) = options.streaming_func {
74            llm_option = llm_option.with_streaming_func(streaming_func)
75        }
76        llm_option
77    }
78
79    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
80        self.max_tokens = Some(max_tokens);
81        self
82    }
83
84    pub fn with_temperature(mut self, temperature: f32) -> Self {
85        self.temperature = Some(temperature);
86        self
87    }
88
89    pub fn with_stop_words(mut self, stop_words: Vec<String>) -> Self {
90        self.stop_words = Some(stop_words);
91        self
92    }
93
94    //TODO:Check if this should be a &str instead of a String
95    pub fn with_streaming_func<F, Fut>(mut self, mut func: F) -> Self
96    where
97        F: FnMut(String) -> Fut + Send + 'static,
98        Fut: Future<Output = Result<(), ()>> + Send + 'static,
99    {
100        self.streaming_func = Some(Box::new(move |s: String| Box::pin(func(s))));
101        self
102    }
103
104    pub fn with_top_k(mut self, top_k: usize) -> Self {
105        self.top_k = Some(top_k);
106        self
107    }
108
109    pub fn with_top_p(mut self, top_p: f32) -> Self {
110        self.top_p = Some(top_p);
111        self
112    }
113
114    pub fn with_seed(mut self, seed: usize) -> Self {
115        self.seed = Some(seed);
116        self
117    }
118
119    pub fn with_min_length(mut self, min_length: usize) -> Self {
120        self.min_length = Some(min_length);
121        self
122    }
123
124    pub fn with_max_length(mut self, max_length: usize) -> Self {
125        self.max_length = Some(max_length);
126        self
127    }
128
129    pub fn with_repetition_penalty(mut self, repetition_penalty: f32) -> Self {
130        self.repetition_penalty = Some(repetition_penalty);
131        self
132    }
133}