1use std::collections::HashMap;
2
3use schemars::JsonSchema;
4use serde::de::DeserializeOwned;
5use tools_rs::ToolCollection;
6
7#[cfg(feature = "stream")]
8use crate::chat::state::{InputStreamed, Streamed};
9use crate::{
10 chat::{
11 Chat,
12 state::{Embedded, Structured, Unstructured},
13 },
14 traits::CompletionProvider,
15 types::{
16 callback::{CallbackStrategy, RetryStrategy},
17 options::ChatOptions,
18 tools::{ScopedCollection, TypedCollection},
19 },
20};
21
22#[cfg(feature = "stream")]
23use crate::traits::StreamProvider;
24
25pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
26 model: Option<CP>,
27 output_shape: Option<schemars::Schema>,
28 model_options: Option<ChatOptions>,
29 max_steps: Option<u16>,
30 max_retries: Option<u16>,
31 retry_strategy: Option<RetryStrategy>,
32 before_strategy: Option<CallbackStrategy>,
33 after_strategy: Option<CallbackStrategy>,
34 scoped_collections: Vec<Box<dyn TypedCollection>>,
35 _output: std::marker::PhantomData<Output>,
36}
37
38impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
39 pub fn new() -> Self {
40 ChatBuilder {
41 _output: std::marker::PhantomData,
42 ..Default::default()
43 }
44 }
45
46 pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
47 where
48 T: JsonSchema + DeserializeOwned,
49 {
50 let shape = schemars::schema_for!(T);
51
52 ChatBuilder {
53 model: self.model,
54 max_steps: self.max_steps,
55 max_retries: self.max_retries,
56 retry_strategy: self.retry_strategy,
57 before_strategy: self.before_strategy,
58 after_strategy: self.after_strategy,
59 output_shape: Some(shape),
60 scoped_collections: self.scoped_collections,
61 model_options: self.model_options,
62 _output: std::marker::PhantomData,
63 }
64 }
65
66 #[cfg(feature = "stream")]
67 pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
68 where
69 CP: StreamProvider,
70 {
71 if self.output_shape.is_some() {
72 println!(
73 "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
74 );
75 }
76
77 ChatBuilder {
78 model: self.model,
79 max_steps: self.max_steps,
80 max_retries: self.max_retries,
81 retry_strategy: self.retry_strategy,
82 before_strategy: self.before_strategy,
83 after_strategy: self.after_strategy,
84 output_shape: None,
85 scoped_collections: self.scoped_collections,
86 model_options: self.model_options,
87 _output: std::marker::PhantomData,
88 }
89 }
90
91 #[cfg(feature = "stream")]
103 pub fn with_input_stream<I>(self) -> ChatBuilder<CP, InputStreamed<I>>
104 where
105 CP: StreamProvider,
106 I: futures::Stream<Item = crate::types::messages::parts::PartEnum>
107 + Send
108 + Unpin
109 + 'static,
110 {
111 if self.output_shape.is_some() {
112 println!(
113 "Warning: Cannot call input-streamed responses with structured outputs. Output shape will be set to None"
114 );
115 }
116
117 ChatBuilder {
118 model: self.model,
119 max_steps: self.max_steps,
120 max_retries: self.max_retries,
121 retry_strategy: self.retry_strategy,
122 before_strategy: self.before_strategy,
123 after_strategy: self.after_strategy,
124 output_shape: None,
125 scoped_collections: self.scoped_collections,
126 model_options: self.model_options,
127 _output: std::marker::PhantomData,
128 }
129 }
130
131 pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
132 if self.output_shape.is_some() {
133 println!(
134 "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
135 );
136 }
137
138 ChatBuilder {
139 model: self.model,
140 max_retries: self.max_retries,
141 retry_strategy: self.retry_strategy,
142 before_strategy: self.before_strategy,
143 after_strategy: self.after_strategy,
144 output_shape: None,
145 scoped_collections: Vec::new(),
146 max_steps: None,
147 model_options: self.model_options,
148 _output: std::marker::PhantomData,
149 }
150 }
151}
152
153impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
154 pub fn with_max_steps(mut self, max_steps: u16) -> Self {
155 self.max_steps = Some(max_steps);
156 self
157 }
158
159 pub fn with_max_retries(mut self, max_retries: u16) -> Self {
160 self.max_retries = Some(max_retries);
161 self
162 }
163
164 pub fn with_tools(mut self, tools: ToolCollection) -> Self {
168 self.scoped_collections
169 .push(Box::new(ScopedCollection::auto_execute(tools)));
170 self
171 }
172
173 pub fn with_scoped_tools<M, F>(mut self, scoped: ScopedCollection<M, F>) -> Self
177 where
178 M: Send + Sync + 'static,
179 F: Fn(&tools_rs::FunctionCall, &M) -> crate::types::tools::Action + Send + Sync + 'static,
180 {
181 self.scoped_collections.push(Box::new(scoped));
182 self
183 }
184
185 pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
186 self.retry_strategy = Some(retry_strategy);
187 self
188 }
189
190 pub fn with_model(mut self, model: CP) -> Self {
191 self.model = Some(model);
192 self
193 }
194
195 pub fn with_options(mut self, options: ChatOptions) -> Self {
196 self.model_options = Some(options);
197 self
198 }
199
200 pub fn build(self) -> Chat<CP, Output> {
201 let mut routing: HashMap<String, usize> = HashMap::new();
206 for (idx, coll) in self.scoped_collections.iter().enumerate() {
207 for name in coll.names() {
208 if routing.contains_key(name) {
209 eprintln!(
210 "chat-rs: tool name `{name}` is registered in multiple scoped \
211 collections; keeping the first registration."
212 );
213 continue;
214 }
215 routing.insert(name.to_string(), idx);
216 }
217 }
218
219 Chat {
220 model: self.model.expect("Need to set a model"),
221 output_shape: self.output_shape,
222 max_steps: self.max_steps,
223 max_retries: self.max_retries,
224 retry_strategy: self.retry_strategy,
225 before_strategy: self.before_strategy,
226 after_strategy: self.after_strategy,
227 scoped_collections: self.scoped_collections,
228 routing,
229 model_options: self.model_options,
230 _output: std::marker::PhantomData,
231 }
232 }
233}
234
235impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
236 fn default() -> Self {
237 ChatBuilder {
238 model: None,
239 output_shape: None,
240 model_options: None,
241 max_steps: None,
242 max_retries: None,
243 retry_strategy: None,
244 before_strategy: None,
245 after_strategy: None,
246 scoped_collections: Vec::new(),
247 _output: std::marker::PhantomData,
248 }
249 }
250}