1use std::collections::HashMap;
2
3use schemars::JsonSchema;
4use serde::de::DeserializeOwned;
5use tools_rs::ToolCollection;
6
7#[cfg(feature = "stream")]
8use crate::chat::state::{InputStreamed, Streamed};
9use crate::{
10 chat::{
11 Chat,
12 state::{Embedded, Structured, Unstructured},
13 },
14 traits::CompletionProvider,
15 types::{
16 callback::{CallbackStrategy, RetryStrategy},
17 options::ChatOptions,
18 tools::{ScopedCollection, TypedCollection},
19 },
20};
21
22#[cfg(feature = "stream")]
23use crate::traits::StreamProvider;
24
25pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
26 model: Option<CP>,
27 output_shape: Option<schemars::Schema>,
28 model_options: Option<ChatOptions>,
29 max_steps: Option<u16>,
30 max_retries: Option<u16>,
31 retry_strategy: Option<RetryStrategy>,
32 before_strategy: Option<CallbackStrategy>,
33 after_strategy: Option<CallbackStrategy>,
34 scoped_collections: Vec<Box<dyn TypedCollection>>,
35 _output: std::marker::PhantomData<Output>,
36}
37
38impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
39 pub fn new() -> Self {
40 ChatBuilder {
41 _output: std::marker::PhantomData,
42 ..Default::default()
43 }
44 }
45
46 pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
47 where
48 T: JsonSchema + DeserializeOwned,
49 {
50 let shape = schemars::schema_for!(T);
51
52 ChatBuilder {
53 model: self.model,
54 max_steps: self.max_steps,
55 max_retries: self.max_retries,
56 retry_strategy: self.retry_strategy,
57 before_strategy: self.before_strategy,
58 after_strategy: self.after_strategy,
59 output_shape: Some(shape),
60 scoped_collections: self.scoped_collections,
61 model_options: self.model_options,
62 _output: std::marker::PhantomData,
63 }
64 }
65
66 #[cfg(feature = "stream")]
67 pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
68 where
69 CP: StreamProvider,
70 {
71 if self.output_shape.is_some() {
72 println!(
73 "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
74 );
75 }
76
77 ChatBuilder {
78 model: self.model,
79 max_steps: self.max_steps,
80 max_retries: self.max_retries,
81 retry_strategy: self.retry_strategy,
82 before_strategy: self.before_strategy,
83 after_strategy: self.after_strategy,
84 output_shape: None,
85 scoped_collections: self.scoped_collections,
86 model_options: self.model_options,
87 _output: std::marker::PhantomData,
88 }
89 }
90
91 #[cfg(feature = "stream")]
101 pub fn with_input_stream(self) -> ChatBuilder<CP, InputStreamed>
102 where
103 CP: StreamProvider,
104 {
105 if self.output_shape.is_some() {
106 println!(
107 "Warning: Cannot call input-streamed responses with structured outputs. Output shape will be set to None"
108 );
109 }
110
111 ChatBuilder {
112 model: self.model,
113 max_steps: self.max_steps,
114 max_retries: self.max_retries,
115 retry_strategy: self.retry_strategy,
116 before_strategy: self.before_strategy,
117 after_strategy: self.after_strategy,
118 output_shape: None,
119 scoped_collections: self.scoped_collections,
120 model_options: self.model_options,
121 _output: std::marker::PhantomData,
122 }
123 }
124
125 pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
126 if self.output_shape.is_some() {
127 println!(
128 "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
129 );
130 }
131
132 ChatBuilder {
133 model: self.model,
134 max_retries: self.max_retries,
135 retry_strategy: self.retry_strategy,
136 before_strategy: self.before_strategy,
137 after_strategy: self.after_strategy,
138 output_shape: None,
139 scoped_collections: Vec::new(),
140 max_steps: None,
141 model_options: self.model_options,
142 _output: std::marker::PhantomData,
143 }
144 }
145}
146
147impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
148 pub fn with_max_steps(mut self, max_steps: u16) -> Self {
149 self.max_steps = Some(max_steps);
150 self
151 }
152
153 pub fn with_max_retries(mut self, max_retries: u16) -> Self {
154 self.max_retries = Some(max_retries);
155 self
156 }
157
158 pub fn with_tools(mut self, tools: ToolCollection) -> Self {
162 self.scoped_collections
163 .push(Box::new(ScopedCollection::auto_execute(tools)));
164 self
165 }
166
167 pub fn with_scoped_tools<M, F>(mut self, scoped: ScopedCollection<M, F>) -> Self
171 where
172 M: Send + Sync + 'static,
173 F: Fn(&tools_rs::FunctionCall, &M) -> crate::types::tools::Action + Send + Sync + 'static,
174 {
175 self.scoped_collections.push(Box::new(scoped));
176 self
177 }
178
179 pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
180 self.retry_strategy = Some(retry_strategy);
181 self
182 }
183
184 pub fn with_model(mut self, model: CP) -> Self {
185 self.model = Some(model);
186 self
187 }
188
189 pub fn with_options(mut self, options: ChatOptions) -> Self {
190 self.model_options = Some(options);
191 self
192 }
193
194 pub fn build(self) -> Chat<CP, Output> {
195 let mut routing: HashMap<String, usize> = HashMap::new();
200 for (idx, coll) in self.scoped_collections.iter().enumerate() {
201 for name in coll.names() {
202 if routing.contains_key(name) {
203 eprintln!(
204 "chat-rs: tool name `{name}` is registered in multiple scoped \
205 collections; keeping the first registration."
206 );
207 continue;
208 }
209 routing.insert(name.to_string(), idx);
210 }
211 }
212
213 Chat {
214 model: self.model.expect("Need to set a model"),
215 output_shape: self.output_shape,
216 max_steps: self.max_steps,
217 max_retries: self.max_retries,
218 retry_strategy: self.retry_strategy,
219 before_strategy: self.before_strategy,
220 after_strategy: self.after_strategy,
221 scoped_collections: self.scoped_collections,
222 routing,
223 model_options: self.model_options,
224 _output: std::marker::PhantomData,
225 }
226 }
227}
228
229impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
230 fn default() -> Self {
231 ChatBuilder {
232 model: None,
233 output_shape: None,
234 model_options: None,
235 max_steps: None,
236 max_retries: None,
237 retry_strategy: None,
238 before_strategy: None,
239 after_strategy: None,
240 scoped_collections: Vec::new(),
241 _output: std::marker::PhantomData,
242 }
243 }
244}