1use schemars::JsonSchema;
2use serde::de::DeserializeOwned;
3use tools_rs::ToolCollection;
4
5use crate::{
6 chat::{
7 Chat,
8 state::{Embedded, Streamed, Structured, Unstructured},
9 },
10 traits::CompletionProvider,
11 types::{
12 callback::{CallbackStrategy, RetryStrategy},
13 options::ChatOptions,
14 },
15};
16
17#[cfg(feature = "stream")]
18use crate::traits::StreamProvider;
19
20pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
21 model: Option<CP>,
22 output_shape: Option<schemars::Schema>,
23 model_options: Option<ChatOptions>,
24 max_steps: Option<u16>,
25 max_retries: Option<u16>,
26 retry_strategy: Option<RetryStrategy>,
27 before_strategy: Option<CallbackStrategy>,
28 after_strategy: Option<CallbackStrategy>,
29 tools: Option<ToolCollection>,
30 _output: std::marker::PhantomData<Output>,
31}
32
33impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
34 pub fn new() -> Self {
35 ChatBuilder {
36 _output: std::marker::PhantomData,
37 ..Default::default()
38 }
39 }
40
41 pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
42 where
43 T: JsonSchema + DeserializeOwned,
44 {
45 let shape = schemars::schema_for!(T);
46
47 ChatBuilder {
48 model: self.model,
49 max_steps: self.max_steps,
50 max_retries: self.max_retries,
51 retry_strategy: self.retry_strategy,
52 before_strategy: self.before_strategy,
53 after_strategy: self.after_strategy,
54 output_shape: Some(shape),
55 tools: self.tools,
56 model_options: self.model_options,
57 _output: std::marker::PhantomData,
58 }
59 }
60
61 #[cfg(feature = "stream")]
62 pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
63 where
64 CP: StreamProvider,
65 {
66 if self.output_shape.is_some() {
67 println!(
68 "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
69 );
70 }
71
72 ChatBuilder {
73 model: self.model,
74 max_steps: self.max_steps,
75 max_retries: self.max_retries,
76 retry_strategy: self.retry_strategy,
77 before_strategy: self.before_strategy,
78 after_strategy: self.after_strategy,
79 output_shape: None, tools: self.tools,
81 model_options: self.model_options,
82 _output: std::marker::PhantomData,
83 }
84 }
85
86 pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
87 if self.output_shape.is_some() {
88 println!(
89 "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
90 );
91 }
92
93 ChatBuilder {
94 model: self.model,
95 max_retries: self.max_retries,
96 retry_strategy: self.retry_strategy,
97 before_strategy: self.before_strategy,
98 after_strategy: self.after_strategy,
99 output_shape: None,
100 tools: None,
101 max_steps: None,
102 model_options: self.model_options,
103 _output: std::marker::PhantomData,
104 }
105 }
106}
107
108impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
109 pub fn with_max_steps(mut self, max_steps: u16) -> Self {
110 self.max_steps = Some(max_steps);
111 self
112 }
113
114 pub fn with_max_retries(mut self, max_retries: u16) -> Self {
115 self.max_retries = Some(max_retries);
116 self
117 }
118
119 pub fn with_tools(mut self, tools: ToolCollection) -> Self {
120 self.tools = Some(tools);
121 self
122 }
123
124 pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
125 self.retry_strategy = Some(retry_strategy);
126 self
127 }
128
129 pub fn with_model(mut self, model: CP) -> Self {
130 self.model = Some(model);
131 self
132 }
133
134 pub fn with_options(mut self, options: ChatOptions) -> Self {
135 self.model_options = Some(options);
136 self
137 }
138
139 pub fn build(self) -> Chat<CP, Output> {
140 Chat {
141 model: self.model.expect("Need to set a model"),
142 output_shape: self.output_shape,
143 max_steps: self.max_steps,
144 max_retries: self.max_retries,
145 retry_strategy: self.retry_strategy,
146 before_strategy: self.before_strategy,
147 after_strategy: self.after_strategy,
148 tools: self.tools,
149 model_options: self.model_options,
150 _output: std::marker::PhantomData,
151 }
152 }
153}
154
155impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
156 fn default() -> Self {
157 ChatBuilder {
158 model: None,
159 output_shape: None,
160 model_options: None,
161 max_steps: None,
162 max_retries: None,
163 retry_strategy: None,
164 before_strategy: None,
165 after_strategy: None,
166 tools: None,
167 _output: std::marker::PhantomData,
168 }
169 }
170}