Skip to main content

chat_core/
builder.rs

1use schemars::JsonSchema;
2use serde::de::DeserializeOwned;
3use tools_rs::ToolCollection;
4
5use crate::{
6    chat::{
7        Chat,
8        state::{Embedded, Streamed, Structured, Unstructured},
9    },
10    traits::CompletionProvider,
11    types::{
12        callback::{CallbackStrategy, RetryStrategy},
13        options::ChatOptions,
14    },
15};
16
17#[cfg(feature = "stream")]
18use crate::traits::StreamProvider;
19
20pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
21    model: Option<CP>,
22    output_shape: Option<schemars::Schema>,
23    model_options: Option<ChatOptions>,
24    max_steps: Option<u16>,
25    max_retries: Option<u16>,
26    retry_strategy: Option<RetryStrategy>,
27    before_strategy: Option<CallbackStrategy>,
28    after_strategy: Option<CallbackStrategy>,
29    tools: Option<ToolCollection>,
30    _output: std::marker::PhantomData<Output>,
31}
32
33impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
34    pub fn new() -> Self {
35        ChatBuilder {
36            _output: std::marker::PhantomData,
37            ..Default::default()
38        }
39    }
40
41    pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
42    where
43        T: JsonSchema + DeserializeOwned,
44    {
45        let shape = schemars::schema_for!(T);
46
47        ChatBuilder {
48            model: self.model,
49            max_steps: self.max_steps,
50            max_retries: self.max_retries,
51            retry_strategy: self.retry_strategy,
52            before_strategy: self.before_strategy,
53            after_strategy: self.after_strategy,
54            output_shape: Some(shape),
55            tools: self.tools,
56            model_options: self.model_options,
57            _output: std::marker::PhantomData,
58        }
59    }
60
61    #[cfg(feature = "stream")]
62    pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
63    where
64        CP: StreamProvider,
65    {
66        if self.output_shape.is_some() {
67            println!(
68                "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
69            );
70        }
71
72        ChatBuilder {
73            model: self.model,
74            max_steps: self.max_steps,
75            max_retries: self.max_retries,
76            retry_strategy: self.retry_strategy,
77            before_strategy: self.before_strategy,
78            after_strategy: self.after_strategy,
79            output_shape: None, // No shape for pure streaming
80            tools: self.tools,
81            model_options: self.model_options,
82            _output: std::marker::PhantomData,
83        }
84    }
85
86    pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
87        if self.output_shape.is_some() {
88            println!(
89                "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
90            );
91        }
92
93        ChatBuilder {
94            model: self.model,
95            max_retries: self.max_retries,
96            retry_strategy: self.retry_strategy,
97            before_strategy: self.before_strategy,
98            after_strategy: self.after_strategy,
99            output_shape: None,
100            tools: None,
101            max_steps: None,
102            model_options: self.model_options,
103            _output: std::marker::PhantomData,
104        }
105    }
106}
107
108impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
109    pub fn with_max_steps(mut self, max_steps: u16) -> Self {
110        self.max_steps = Some(max_steps);
111        self
112    }
113
114    pub fn with_max_retries(mut self, max_retries: u16) -> Self {
115        self.max_retries = Some(max_retries);
116        self
117    }
118
119    pub fn with_tools(mut self, tools: ToolCollection) -> Self {
120        self.tools = Some(tools);
121        self
122    }
123
124    pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
125        self.retry_strategy = Some(retry_strategy);
126        self
127    }
128
129    pub fn with_model(mut self, model: CP) -> Self {
130        self.model = Some(model);
131        self
132    }
133
134    pub fn with_options(mut self, options: ChatOptions) -> Self {
135        self.model_options = Some(options);
136        self
137    }
138
139    pub fn build(self) -> Chat<CP, Output> {
140        Chat {
141            model: self.model.expect("Need to set a model"),
142            output_shape: self.output_shape,
143            max_steps: self.max_steps,
144            max_retries: self.max_retries,
145            retry_strategy: self.retry_strategy,
146            before_strategy: self.before_strategy,
147            after_strategy: self.after_strategy,
148            tools: self.tools,
149            model_options: self.model_options,
150            _output: std::marker::PhantomData,
151        }
152    }
153}
154
155impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
156    fn default() -> Self {
157        ChatBuilder {
158            model: None,
159            output_shape: None,
160            model_options: None,
161            max_steps: None,
162            max_retries: None,
163            retry_strategy: None,
164            before_strategy: None,
165            after_strategy: None,
166            tools: None,
167            _output: std::marker::PhantomData,
168        }
169    }
170}