1use std::collections::HashMap;
2
3use schemars::JsonSchema;
4use serde::de::DeserializeOwned;
5use tools_rs::ToolCollection;
6
7#[cfg(feature = "stream")]
8use crate::chat::state::Streamed;
9use crate::{
10 chat::{
11 Chat,
12 state::{Embedded, Structured, Unstructured},
13 },
14 traits::CompletionProvider,
15 types::{
16 callback::{CallbackStrategy, RetryStrategy},
17 options::ChatOptions,
18 tools::{ScopedCollection, TypedCollection},
19 },
20};
21
22#[cfg(feature = "stream")]
23use crate::traits::StreamProvider;
24
25pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
26 model: Option<CP>,
27 output_shape: Option<schemars::Schema>,
28 model_options: Option<ChatOptions>,
29 max_steps: Option<u16>,
30 max_retries: Option<u16>,
31 retry_strategy: Option<RetryStrategy>,
32 before_strategy: Option<CallbackStrategy>,
33 after_strategy: Option<CallbackStrategy>,
34 scoped_collections: Vec<Box<dyn TypedCollection>>,
35 _output: std::marker::PhantomData<Output>,
36}
37
38impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
39 pub fn new() -> Self {
40 ChatBuilder {
41 _output: std::marker::PhantomData,
42 ..Default::default()
43 }
44 }
45
46 pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
47 where
48 T: JsonSchema + DeserializeOwned,
49 {
50 let shape = schemars::schema_for!(T);
51
52 ChatBuilder {
53 model: self.model,
54 max_steps: self.max_steps,
55 max_retries: self.max_retries,
56 retry_strategy: self.retry_strategy,
57 before_strategy: self.before_strategy,
58 after_strategy: self.after_strategy,
59 output_shape: Some(shape),
60 scoped_collections: self.scoped_collections,
61 model_options: self.model_options,
62 _output: std::marker::PhantomData,
63 }
64 }
65
66 #[cfg(feature = "stream")]
67 pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
68 where
69 CP: StreamProvider,
70 {
71 if self.output_shape.is_some() {
72 println!(
73 "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
74 );
75 }
76
77 ChatBuilder {
78 model: self.model,
79 max_steps: self.max_steps,
80 max_retries: self.max_retries,
81 retry_strategy: self.retry_strategy,
82 before_strategy: self.before_strategy,
83 after_strategy: self.after_strategy,
84 output_shape: None,
85 scoped_collections: self.scoped_collections,
86 model_options: self.model_options,
87 _output: std::marker::PhantomData,
88 }
89 }
90
91 pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
92 if self.output_shape.is_some() {
93 println!(
94 "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
95 );
96 }
97
98 ChatBuilder {
99 model: self.model,
100 max_retries: self.max_retries,
101 retry_strategy: self.retry_strategy,
102 before_strategy: self.before_strategy,
103 after_strategy: self.after_strategy,
104 output_shape: None,
105 scoped_collections: Vec::new(),
106 max_steps: None,
107 model_options: self.model_options,
108 _output: std::marker::PhantomData,
109 }
110 }
111}
112
113impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
114 pub fn with_max_steps(mut self, max_steps: u16) -> Self {
115 self.max_steps = Some(max_steps);
116 self
117 }
118
119 pub fn with_max_retries(mut self, max_retries: u16) -> Self {
120 self.max_retries = Some(max_retries);
121 self
122 }
123
124 pub fn with_tools(mut self, tools: ToolCollection) -> Self {
128 self.scoped_collections
129 .push(Box::new(ScopedCollection::auto_execute(tools)));
130 self
131 }
132
133 pub fn with_scoped_tools<M, F>(mut self, scoped: ScopedCollection<M, F>) -> Self
137 where
138 M: Send + Sync + 'static,
139 F: Fn(&tools_rs::FunctionCall, &M) -> crate::types::tools::Action + Send + Sync + 'static,
140 {
141 self.scoped_collections.push(Box::new(scoped));
142 self
143 }
144
145 pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
146 self.retry_strategy = Some(retry_strategy);
147 self
148 }
149
150 pub fn with_model(mut self, model: CP) -> Self {
151 self.model = Some(model);
152 self
153 }
154
155 pub fn with_options(mut self, options: ChatOptions) -> Self {
156 self.model_options = Some(options);
157 self
158 }
159
160 pub fn build(self) -> Chat<CP, Output> {
161 let mut routing: HashMap<String, usize> = HashMap::new();
166 for (idx, coll) in self.scoped_collections.iter().enumerate() {
167 for name in coll.names() {
168 if routing.contains_key(name) {
169 eprintln!(
170 "chat-rs: tool name `{name}` is registered in multiple scoped \
171 collections; keeping the first registration."
172 );
173 continue;
174 }
175 routing.insert(name.to_string(), idx);
176 }
177 }
178
179 Chat {
180 model: self.model.expect("Need to set a model"),
181 output_shape: self.output_shape,
182 max_steps: self.max_steps,
183 max_retries: self.max_retries,
184 retry_strategy: self.retry_strategy,
185 before_strategy: self.before_strategy,
186 after_strategy: self.after_strategy,
187 scoped_collections: self.scoped_collections,
188 routing,
189 model_options: self.model_options,
190 _output: std::marker::PhantomData,
191 }
192 }
193}
194
195impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
196 fn default() -> Self {
197 ChatBuilder {
198 model: None,
199 output_shape: None,
200 model_options: None,
201 max_steps: None,
202 max_retries: None,
203 retry_strategy: None,
204 before_strategy: None,
205 after_strategy: None,
206 scoped_collections: Vec::new(),
207 _output: std::marker::PhantomData,
208 }
209 }
210}