1use std::collections::HashMap;
2
3use schemars::JsonSchema;
4use serde::de::DeserializeOwned;
5use tools_rs::ToolCollection;
6
7#[cfg(feature = "stream")]
8use crate::chat::state::InputStreamed;
9use crate::{
10 chat::{
11 Chat,
12 state::{Embedded, Structured, Unstructured},
13 },
14 traits::CompletionProvider,
15 types::{
16 callback::{CallbackStrategy, RetryStrategy},
17 options::ChatOptions,
18 tools::{ScopedCollection, TypedCollection},
19 },
20};
21
22#[cfg(feature = "stream")]
23use crate::traits::StreamProvider;
24
25pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
26 model: Option<CP>,
27 output_shape: Option<schemars::Schema>,
28 model_options: Option<ChatOptions>,
29 max_steps: Option<u16>,
30 max_retries: Option<u16>,
31 retry_strategy: Option<RetryStrategy>,
32 before_strategy: Option<CallbackStrategy>,
33 after_strategy: Option<CallbackStrategy>,
34 scoped_collections: Vec<Box<dyn TypedCollection>>,
35 _output: std::marker::PhantomData<Output>,
36}
37
38impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
39 pub fn new() -> Self {
40 ChatBuilder {
41 _output: std::marker::PhantomData,
42 ..Default::default()
43 }
44 }
45
46 pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
47 where
48 T: JsonSchema + DeserializeOwned,
49 {
50 let shape = schemars::schema_for!(T);
51
52 ChatBuilder {
53 model: self.model,
54 max_steps: self.max_steps,
55 max_retries: self.max_retries,
56 retry_strategy: self.retry_strategy,
57 before_strategy: self.before_strategy,
58 after_strategy: self.after_strategy,
59 output_shape: Some(shape),
60 scoped_collections: self.scoped_collections,
61 model_options: self.model_options,
62 _output: std::marker::PhantomData,
63 }
64 }
65
66 #[cfg(feature = "stream")]
76 pub fn with_input_stream(self) -> ChatBuilder<CP, InputStreamed>
77 where
78 CP: StreamProvider,
79 {
80 if self.output_shape.is_some() {
81 println!(
82 "Warning: Cannot call input-streamed responses with structured outputs. Output shape will be set to None"
83 );
84 }
85
86 ChatBuilder {
87 model: self.model,
88 max_steps: self.max_steps,
89 max_retries: self.max_retries,
90 retry_strategy: self.retry_strategy,
91 before_strategy: self.before_strategy,
92 after_strategy: self.after_strategy,
93 output_shape: None,
94 scoped_collections: self.scoped_collections,
95 model_options: self.model_options,
96 _output: std::marker::PhantomData,
97 }
98 }
99
100 pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
101 if self.output_shape.is_some() {
102 println!(
103 "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
104 );
105 }
106
107 ChatBuilder {
108 model: self.model,
109 max_retries: self.max_retries,
110 retry_strategy: self.retry_strategy,
111 before_strategy: self.before_strategy,
112 after_strategy: self.after_strategy,
113 output_shape: None,
114 scoped_collections: Vec::new(),
115 max_steps: None,
116 model_options: self.model_options,
117 _output: std::marker::PhantomData,
118 }
119 }
120}
121
122impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
123 pub fn with_max_steps(mut self, max_steps: u16) -> Self {
124 self.max_steps = Some(max_steps);
125 self
126 }
127
128 pub fn with_max_retries(mut self, max_retries: u16) -> Self {
129 self.max_retries = Some(max_retries);
130 self
131 }
132
133 pub fn with_tools(mut self, tools: ToolCollection) -> Self {
137 self.scoped_collections
138 .push(Box::new(ScopedCollection::auto_execute(tools)));
139 self
140 }
141
142 pub fn with_scoped_tools<M, F>(mut self, scoped: ScopedCollection<M, F>) -> Self
146 where
147 M: Send + Sync + 'static,
148 F: Fn(&tools_rs::FunctionCall, &M) -> crate::types::tools::Action + Send + Sync + 'static,
149 {
150 self.scoped_collections.push(Box::new(scoped));
151 self
152 }
153
154 pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
155 self.retry_strategy = Some(retry_strategy);
156 self
157 }
158
159 pub fn with_before_strategy(mut self, before_strategy: CallbackStrategy) -> Self {
163 self.before_strategy = Some(before_strategy);
164 self
165 }
166
167 pub fn with_after_strategy(mut self, after_strategy: CallbackStrategy) -> Self {
170 self.after_strategy = Some(after_strategy);
171 self
172 }
173
174 pub fn with_model(mut self, model: CP) -> Self {
175 self.model = Some(model);
176 self
177 }
178
179 pub fn with_options(mut self, options: ChatOptions) -> Self {
180 self.model_options = Some(options);
181 self
182 }
183
184 pub fn build(self) -> Chat<CP, Output> {
185 let mut routing: HashMap<String, usize> = HashMap::new();
186 for (idx, coll) in self.scoped_collections.iter().enumerate() {
187 for name in coll.names() {
188 if routing.contains_key(name) {
189 eprintln!(
190 "chat-rs: tool name `{name}` is registered in multiple scoped \
191 collections; keeping the first registration."
192 );
193 continue;
194 }
195 routing.insert(name.to_string(), idx);
196 }
197 }
198
199 Chat {
200 model: self.model.expect("Need to set a model"),
201 output_shape: self.output_shape,
202 max_steps: self.max_steps,
203 max_retries: self.max_retries,
204 retry_strategy: self.retry_strategy,
205 before_strategy: self.before_strategy,
206 after_strategy: self.after_strategy,
207 scoped_collections: self.scoped_collections,
208 routing,
209 model_options: self.model_options,
210 _output: std::marker::PhantomData,
211 }
212 }
213}
214
215impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
216 fn default() -> Self {
217 ChatBuilder {
218 model: None,
219 output_shape: None,
220 model_options: None,
221 max_steps: None,
222 max_retries: None,
223 retry_strategy: None,
224 before_strategy: None,
225 after_strategy: None,
226 scoped_collections: Vec::new(),
227 _output: std::marker::PhantomData,
228 }
229 }
230}