1use crate::anthropic::v1::request::AnthropicMessageRequestV1;
2use crate::anthropic::MessageParam;
3use crate::google::v1::generate::request::GeminiGenerateContentRequestV1;
4use crate::google::GeminiContent;
5use crate::openai::v1::chat::request::OpenAIChatCompletionRequestV1;
6use crate::openai::ChatMessage;
7use crate::prompt::types::MessageNum;
8use crate::prompt::ModelSettings;
9use crate::tools::AgentToolDefinition;
10use crate::traits::RequestAdapter;
11use crate::{Provider, TypeError};
12use potatohead_macro::dispatch_trait_method;
13use pyo3::types::PyList;
14use pyo3::types::PyListMethods;
15use pyo3::Python;
16use pyo3::{Bound, PyAny};
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19
20pub trait RequestBuilder {
22 type Request: Serialize;
23
24 fn build_request(
26 messages: Vec<MessageNum>,
27 system_instructions: Vec<MessageNum>,
28 model: String,
29 settings: ModelSettings,
30 response_format: Option<Value>,
31 ) -> Result<Self::Request, TypeError>;
32}
33
/// The provider request family that a [`MessageNum`] belongs to.
///
/// Used to select which provider-specific builder turns a list of messages into a request.
/// `PartialEq`/`Eq`/`Hash` are derived so variants can be compared and used as keys directly,
/// instead of through `as u8` casts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RequestType {
    /// OpenAI chat-completion request (v1).
    OpenAIChatV1,
    /// Anthropic messages request (v1).
    AnthropicMessageV1,
    /// Gemini generate-content request (v1).
    GeminiContentV1,
}
41
42impl MessageNum {
43 pub fn request_type(&self) -> RequestType {
45 match self {
46 MessageNum::OpenAIMessageV1(_) => RequestType::OpenAIChatV1,
47 MessageNum::AnthropicMessageV1(_) => RequestType::AnthropicMessageV1,
48 MessageNum::GeminiContentV1(_) => RequestType::GeminiContentV1,
49 MessageNum::AnthropicSystemMessageV1(_) => RequestType::AnthropicMessageV1,
50 }
51 }
52}
53
/// A request body for one of the supported provider APIs.
///
/// Because of `#[serde(untagged)]`, serialization emits only the inner request's fields with no
/// variant tag. On deserialization serde tries the variants in declaration order and keeps the
/// first one that succeeds, so do not reorder the variants without checking that inputs still
/// resolve to the intended provider.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ProviderRequest {
    /// OpenAI chat-completion request (v1).
    OpenAIV1(OpenAIChatCompletionRequestV1),
    /// Anthropic messages request (v1).
    AnthropicV1(AnthropicMessageRequestV1),
    /// Gemini generate-content request (v1).
    GeminiV1(GeminiGenerateContentRequestV1),
}
63
64impl ProviderRequest {
65 pub fn insert_message(&mut self, message: MessageNum, idx: Option<usize>) {
66 self.messages_mut().insert(idx.unwrap_or(0), message);
67 }
68
69 pub fn push_message(&mut self, message: MessageNum) {
70 self.messages_mut().push(message);
71 }
72
73 pub fn messages(&self) -> &[MessageNum] {
74 dispatch_trait_method!(self, RequestAdapter, messages())
75 }
76
77 pub fn system_instructions(&self) -> Vec<&MessageNum> {
78 dispatch_trait_method!(self, RequestAdapter, system_instructions())
79 }
80
81 pub fn messages_mut(&mut self) -> &mut Vec<MessageNum> {
82 dispatch_trait_method!(mut self, RequestAdapter, messages_mut())
83 }
84
85 pub fn add_tools(&mut self, tools: Vec<AgentToolDefinition>) -> Result<(), TypeError> {
86 dispatch_trait_method!(mut self, RequestAdapter, add_tools(tools))
87 }
88
89 pub fn prepend_system_instructions(
90 &mut self,
91 instructions: Vec<MessageNum>,
92 ) -> Result<(), TypeError> {
93 dispatch_trait_method!(mut self, RequestAdapter, preprend_system_instructions(instructions))
94 }
95
96 pub(crate) fn get_py_messages<'py>(
98 &self,
99 py: Python<'py>,
100 ) -> Result<Bound<'py, PyList>, TypeError> {
101 let py_messages = PyList::empty(py);
102
103 for msg in self.messages() {
104 if msg.is_user_message() {
105 py_messages.append(msg.to_bound_py_object(py)?)?;
106 }
107 }
108
109 Ok(py_messages)
110 }
111
112 pub(crate) fn get_all_py_messages<'py>(
113 &self,
114 py: Python<'py>,
115 ) -> Result<Bound<'py, PyList>, TypeError> {
116 let py_messages = PyList::empty(py);
117
118 for msg in self.messages() {
119 py_messages.append(msg.to_bound_py_object(py)?)?;
120 }
121
122 Ok(py_messages)
123 }
124
125 pub(crate) fn get_py_message<'py>(
127 &self,
128 py: Python<'py>,
129 ) -> Result<Bound<'py, PyAny>, TypeError> {
130 let last = self
131 .messages()
132 .iter()
133 .rev()
134 .find(|msg| msg.is_user_message())
135 .ok_or_else(|| {
136 TypeError::Error("No messages in request to convert to Python object".to_string())
137 })?;
138
139 last.to_bound_py_object(py)
140 }
141
142 pub(crate) fn get_openai_message(&self) -> Result<ChatMessage, TypeError> {
143 let last = self
144 .messages()
145 .iter()
146 .rev()
147 .find(|msg| msg.is_user_message())
148 .ok_or_else(|| {
149 TypeError::Error("No messages in request to convert to Python object".to_string())
150 })?;
151
152 match last {
153 MessageNum::OpenAIMessageV1(msg) => Ok(msg.clone()),
154 _ => Err(TypeError::Error(
155 "Last message is not an OpenAI ChatMessage".to_string(),
156 )),
157 }
158 }
159
160 pub(crate) fn get_gemini_message(&self) -> Result<GeminiContent, TypeError> {
162 let last = self
163 .messages()
164 .iter()
165 .rev()
166 .find(|msg| msg.is_user_message())
167 .ok_or_else(|| {
168 TypeError::Error("No messages in request to convert to Python object".to_string())
169 })?;
170
171 match last {
172 MessageNum::GeminiContentV1(msg) => Ok(msg.clone()),
173 _ => Err(TypeError::Error(
174 "Last message is not a GeminiContent".to_string(),
175 )),
176 }
177 }
178
179 pub(crate) fn get_anthropic_message(&self) -> Result<MessageParam, TypeError> {
181 let last = self
182 .messages()
183 .iter()
184 .rev()
185 .find(|msg| msg.is_user_message())
186 .ok_or_else(|| {
187 TypeError::Error("No messages in request to convert to Python object".to_string())
188 })?;
189
190 Ok(match last {
191 MessageNum::AnthropicMessageV1(msg) => msg.clone(),
192 _ => {
193 return Err(TypeError::Error(
194 "Last message is not an Anthropic MessageParam".to_string(),
195 ))
196 }
197 })
198 }
199
200 pub(crate) fn get_py_system_instructions<'py>(
201 &self,
202 py: Python<'py>,
203 ) -> Result<Bound<'py, PyList>, TypeError> {
204 dispatch_trait_method!(self, RequestAdapter, get_py_system_instructions(py))
205 }
206
207 pub(crate) fn model_settings<'py>(
208 &self,
209 py: Python<'py>,
210 ) -> Result<Bound<'py, PyAny>, TypeError> {
211 dispatch_trait_method!(self, RequestAdapter, model_settings(py))
212 }
213
214 pub fn response_json_schema(&self) -> Option<&Value> {
215 dispatch_trait_method!(self, RequestAdapter, response_json_schema())
216 }
217
218 pub fn has_structured_output(&self) -> bool {
219 self.response_json_schema().is_some()
220 }
221
222 pub fn to_request(&self, provider: &Provider) -> Result<Value, TypeError> {
226 let is_matched = dispatch_trait_method!(self, RequestAdapter, match_provider(provider));
227
228 if !is_matched {
229 return Err(TypeError::Error(
230 "ProviderRequest does not match the specified provider".to_string(),
231 ));
232 }
233 dispatch_trait_method!(self, RequestAdapter, to_request_body())
234 }
235
236 pub fn to_json(&self) -> Result<Value, TypeError> {
238 Ok(serde_json::to_value(self)?)
239 }
240
241 pub fn set_response_json_schema(&mut self, response_json_schema: Option<Value>) {
242 dispatch_trait_method!(mut self, RequestAdapter, set_response_json_schema(response_json_schema))
243 }
244}
245
246pub fn to_provider_request(
247 messages: Vec<MessageNum>,
248 system_instructions: Vec<MessageNum>,
249 model: String,
250 model_settings: ModelSettings,
251 response_json_schema: Option<Value>,
252) -> Result<ProviderRequest, TypeError> {
253 let request_type = messages
255 .first()
256 .ok_or_else(|| TypeError::Error("Prompt has no messages".to_string()))?
257 .request_type();
258
259 for msg in &messages {
261 if msg.request_type() as u8 != request_type as u8 {
262 return Err(TypeError::Error(
263 "All messages must be of the same provider type".to_string(),
264 ));
265 }
266 }
267
268 match request_type {
270 RequestType::OpenAIChatV1 => OpenAIChatCompletionRequestV1::build_provider_enum(
271 messages,
272 system_instructions,
273 model,
274 model_settings,
275 response_json_schema,
276 ),
277 RequestType::AnthropicMessageV1 => AnthropicMessageRequestV1::build_provider_enum(
278 messages,
279 system_instructions,
280 model,
281 model_settings,
282 response_json_schema,
283 ),
284 RequestType::GeminiContentV1 => GeminiGenerateContentRequestV1::build_provider_enum(
285 messages,
286 system_instructions,
287 model,
288 model_settings,
289 response_json_schema,
290 ),
291 }
292}