Skip to main content

potato_type/prompt/
builder.rs

1use crate::anthropic::v1::request::AnthropicMessageRequestV1;
2use crate::anthropic::MessageParam;
3use crate::google::v1::generate::request::GeminiGenerateContentRequestV1;
4use crate::google::GeminiContent;
5use crate::openai::v1::chat::request::OpenAIChatCompletionRequestV1;
6use crate::openai::ChatMessage;
7use crate::prompt::types::MessageNum;
8use crate::prompt::ModelSettings;
9use crate::tools::AgentToolDefinition;
10use crate::traits::RequestAdapter;
11use crate::{Provider, TypeError};
12use potatohead_macro::dispatch_trait_method;
13use pyo3::types::PyList;
14use pyo3::types::PyListMethods;
15use pyo3::Python;
16use pyo3::{Bound, PyAny};
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19
20/// Trait for converting a list of messages into a provider-specific request
21pub trait RequestBuilder {
22    type Request: Serialize;
23
24    /// Build a request from messages and settings
25    fn build_request(
26        messages: Vec<MessageNum>,
27        system_instructions: Vec<MessageNum>,
28        model: String,
29        settings: ModelSettings,
30        response_format: Option<Value>,
31    ) -> Result<Self::Request, TypeError>;
32}
33
/// Type marker used to route a list of messages to the matching
/// provider-specific request builder.
///
/// Derives `PartialEq`/`Eq`/`Hash` so variants can be compared or used as map
/// keys directly, instead of casting to an integer first.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RequestType {
    /// Route to the OpenAI chat-completions request (v1).
    OpenAIChatV1,
    /// Route to the Anthropic messages request (v1).
    AnthropicMessageV1,
    /// Route to the Gemini generate-content request (v1).
    GeminiContentV1,
}
41
42impl MessageNum {
43    /// Determine the request type from this message variant
44    pub fn request_type(&self) -> RequestType {
45        match self {
46            MessageNum::OpenAIMessageV1(_) => RequestType::OpenAIChatV1,
47            MessageNum::AnthropicMessageV1(_) => RequestType::AnthropicMessageV1,
48            MessageNum::GeminiContentV1(_) => RequestType::GeminiContentV1,
49            MessageNum::AnthropicSystemMessageV1(_) => RequestType::AnthropicMessageV1,
50        }
51    }
52}
53
/// Unified enum for provider-specific requests.
/// This serves as a central access point for accessing request attributes from within a Prompt.
///
/// Serialization: because of `#[serde(untagged)]`, a value serializes as the
/// inner request body with no variant tag.
///
/// Deserialization: serde tries untagged variants in declaration order and keeps
/// the first one that deserializes successfully, so the ORDER of the variants
/// below is significant — do not reorder them casually.
/// NOTE(review): if an earlier request type would also accept another
/// provider's JSON (e.g. all fields optional and unknown fields ignored), it
/// shadows the later variants — confirm the request structs are distinguishable.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ProviderRequest {
    /// OpenAI chat-completions request (v1).
    OpenAIV1(OpenAIChatCompletionRequestV1),
    /// Anthropic messages request (v1).
    AnthropicV1(AnthropicMessageRequestV1),
    /// Gemini generate-content request (v1).
    GeminiV1(GeminiGenerateContentRequestV1),
}
63
64impl ProviderRequest {
65    pub fn insert_message(&mut self, message: MessageNum, idx: Option<usize>) {
66        self.messages_mut().insert(idx.unwrap_or(0), message);
67    }
68
69    pub fn push_message(&mut self, message: MessageNum) {
70        self.messages_mut().push(message);
71    }
72
73    pub fn messages(&self) -> &[MessageNum] {
74        dispatch_trait_method!(self, RequestAdapter, messages())
75    }
76
77    pub fn system_instructions(&self) -> Vec<&MessageNum> {
78        dispatch_trait_method!(self, RequestAdapter, system_instructions())
79    }
80
81    pub fn messages_mut(&mut self) -> &mut Vec<MessageNum> {
82        dispatch_trait_method!(mut self, RequestAdapter, messages_mut())
83    }
84
85    pub fn add_tools(&mut self, tools: Vec<AgentToolDefinition>) -> Result<(), TypeError> {
86        dispatch_trait_method!(mut self, RequestAdapter, add_tools(tools))
87    }
88
89    pub fn prepend_system_instructions(
90        &mut self,
91        instructions: Vec<MessageNum>,
92    ) -> Result<(), TypeError> {
93        dispatch_trait_method!(mut self, RequestAdapter, preprend_system_instructions(instructions))
94    }
95
96    /// Returns the messages as a Python list
97    pub(crate) fn get_py_messages<'py>(
98        &self,
99        py: Python<'py>,
100    ) -> Result<Bound<'py, PyList>, TypeError> {
101        let py_messages = PyList::empty(py);
102
103        for msg in self.messages() {
104            if msg.is_user_message() {
105                py_messages.append(msg.to_bound_py_object(py)?)?;
106            }
107        }
108
109        Ok(py_messages)
110    }
111
112    pub(crate) fn get_all_py_messages<'py>(
113        &self,
114        py: Python<'py>,
115    ) -> Result<Bound<'py, PyList>, TypeError> {
116        let py_messages = PyList::empty(py);
117
118        for msg in self.messages() {
119            py_messages.append(msg.to_bound_py_object(py)?)?;
120        }
121
122        Ok(py_messages)
123    }
124
125    /// Returns the last message in the request as a Python object
126    pub(crate) fn get_py_message<'py>(
127        &self,
128        py: Python<'py>,
129    ) -> Result<Bound<'py, PyAny>, TypeError> {
130        let last = self
131            .messages()
132            .iter()
133            .rev()
134            .find(|msg| msg.is_user_message())
135            .ok_or_else(|| {
136                TypeError::Error("No messages in request to convert to Python object".to_string())
137            })?;
138
139        last.to_bound_py_object(py)
140    }
141
142    pub(crate) fn get_openai_message(&self) -> Result<ChatMessage, TypeError> {
143        let last = self
144            .messages()
145            .iter()
146            .rev()
147            .find(|msg| msg.is_user_message())
148            .ok_or_else(|| {
149                TypeError::Error("No messages in request to convert to Python object".to_string())
150            })?;
151
152        match last {
153            MessageNum::OpenAIMessageV1(msg) => Ok(msg.clone()),
154            _ => Err(TypeError::Error(
155                "Last message is not an OpenAI ChatMessage".to_string(),
156            )),
157        }
158    }
159
160    /// Returns the messages as Anthropic MessageParam Python objects
161    pub(crate) fn get_gemini_message(&self) -> Result<GeminiContent, TypeError> {
162        let last = self
163            .messages()
164            .iter()
165            .rev()
166            .find(|msg| msg.is_user_message())
167            .ok_or_else(|| {
168                TypeError::Error("No messages in request to convert to Python object".to_string())
169            })?;
170
171        match last {
172            MessageNum::GeminiContentV1(msg) => Ok(msg.clone()),
173            _ => Err(TypeError::Error(
174                "Last message is not a GeminiContent".to_string(),
175            )),
176        }
177    }
178
179    /// Returns the messages as Anthropic MessageParam Python objects
180    pub(crate) fn get_anthropic_message(&self) -> Result<MessageParam, TypeError> {
181        let last = self
182            .messages()
183            .iter()
184            .rev()
185            .find(|msg| msg.is_user_message())
186            .ok_or_else(|| {
187                TypeError::Error("No messages in request to convert to Python object".to_string())
188            })?;
189
190        Ok(match last {
191            MessageNum::AnthropicMessageV1(msg) => msg.clone(),
192            _ => {
193                return Err(TypeError::Error(
194                    "Last message is not an Anthropic MessageParam".to_string(),
195                ))
196            }
197        })
198    }
199
200    pub(crate) fn get_py_system_instructions<'py>(
201        &self,
202        py: Python<'py>,
203    ) -> Result<Bound<'py, PyList>, TypeError> {
204        dispatch_trait_method!(self, RequestAdapter, get_py_system_instructions(py))
205    }
206
207    pub(crate) fn model_settings<'py>(
208        &self,
209        py: Python<'py>,
210    ) -> Result<Bound<'py, PyAny>, TypeError> {
211        dispatch_trait_method!(self, RequestAdapter, model_settings(py))
212    }
213
214    pub fn response_json_schema(&self) -> Option<&Value> {
215        dispatch_trait_method!(self, RequestAdapter, response_json_schema())
216    }
217
218    pub fn has_structured_output(&self) -> bool {
219        self.response_json_schema().is_some()
220    }
221
222    /// Retrieve the JSON request body for the specified provider
223    /// This method will first attempt to match the provider type,
224    /// returning an error if there is a mismatch.
225    pub fn to_request(&self, provider: &Provider) -> Result<Value, TypeError> {
226        let is_matched = dispatch_trait_method!(self, RequestAdapter, match_provider(provider));
227
228        if !is_matched {
229            return Err(TypeError::Error(
230                "ProviderRequest does not match the specified provider".to_string(),
231            ));
232        }
233        dispatch_trait_method!(self, RequestAdapter, to_request_body())
234    }
235
236    /// Serialize to JSON for API requests
237    pub fn to_json(&self) -> Result<Value, TypeError> {
238        Ok(serde_json::to_value(self)?)
239    }
240
241    pub fn set_response_json_schema(&mut self, response_json_schema: Option<Value>) {
242        dispatch_trait_method!(mut self, RequestAdapter, set_response_json_schema(response_json_schema))
243    }
244}
245
246pub fn to_provider_request(
247    messages: Vec<MessageNum>,
248    system_instructions: Vec<MessageNum>,
249    model: String,
250    model_settings: ModelSettings,
251    response_json_schema: Option<Value>,
252) -> Result<ProviderRequest, TypeError> {
253    // Determine request type from first message
254    let request_type = messages
255        .first()
256        .ok_or_else(|| TypeError::Error("Prompt has no messages".to_string()))?
257        .request_type();
258
259    // Validate all messages are same type
260    for msg in &messages {
261        if msg.request_type() as u8 != request_type as u8 {
262            return Err(TypeError::Error(
263                "All messages must be of the same provider type".to_string(),
264            ));
265        }
266    }
267
268    // Build appropriate request based on type
269    match request_type {
270        RequestType::OpenAIChatV1 => OpenAIChatCompletionRequestV1::build_provider_enum(
271            messages,
272            system_instructions,
273            model,
274            model_settings,
275            response_json_schema,
276        ),
277        RequestType::AnthropicMessageV1 => AnthropicMessageRequestV1::build_provider_enum(
278            messages,
279            system_instructions,
280            model,
281            model_settings,
282            response_json_schema,
283        ),
284        RequestType::GeminiContentV1 => GeminiGenerateContentRequestV1::build_provider_enum(
285            messages,
286            system_instructions,
287            model,
288            model_settings,
289            response_json_schema,
290        ),
291    }
292}