potato_type/prompt/
types.rs

1use crate::anthropic::v1::request::MessageParam as AnthropicMessage;
2use crate::anthropic::v1::request::TextBlockParam;
3use crate::anthropic::v1::response::ResponseContentBlock;
4use crate::google::v1::generate::Candidate;
5use crate::google::v1::generate::GeminiContent;
6use crate::google::PredictResponse;
7use crate::openai::v1::chat::request::ChatMessage as OpenAIChatMessage;
8use crate::openai::v1::Choice;
9use crate::traits::MessageConversion;
10use crate::traits::PromptMessageExt;
11use crate::Provider;
12use crate::{StructuredOutput, TypeError};
13use potato_util::PyHelperFuncs;
14use pyo3::types::PyAnyMethods;
15use pyo3::{prelude::*, IntoPyObjectExt};
16use pythonize::{depythonize, pythonize};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::fmt::Display;
21use tracing::error;
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23#[pyclass]
24pub enum Role {
25    User,
26    Assistant,
27    Developer,
28    Tool,
29    Model,
30    System,
31}
32
33#[pymethods]
34impl Role {
35    #[pyo3(name = "as_str")]
36    pub fn as_str_py(&self) -> &'static str {
37        self.as_str()
38    }
39}
40
41impl Role {
42    /// Returns the string representation of the role
43    pub const fn as_str(&self) -> &'static str {
44        match self {
45            Role::User => "user",
46            Role::Assistant => "assistant",
47            Role::Developer => "developer",
48            Role::Tool => "tool",
49            Role::Model => "model",
50            Role::System => "system",
51        }
52    }
53}
54
55impl Display for Role {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        match self {
58            Role::User => write!(f, "user"),
59            Role::Assistant => write!(f, "assistant"),
60            Role::Developer => write!(f, "developer"),
61            Role::Tool => write!(f, "tool"),
62            Role::Model => write!(f, "model"),
63            Role::System => write!(f, "system"),
64        }
65    }
66}
67
68impl From<Role> for &str {
69    fn from(role: Role) -> Self {
70        match role {
71            Role::User => "user",
72            Role::Assistant => "assistant",
73            Role::Developer => "developer",
74            Role::Tool => "tool",
75            Role::Model => "model",
76            Role::System => "system",
77        }
78    }
79}
80
81pub trait DeserializePromptValExt: for<'de> serde::Deserialize<'de> {
82    /// Validates and deserializes a JSON value into its struct type.
83    ///
84    /// # Arguments
85    /// * `value` - The JSON value to deserialize
86    ///
87    /// # Returns
88    /// * `Result<Self, serde_json::Error>` - The deserialized value or error
89    fn model_validate_json(value: &Value) -> Result<Self, serde_json::Error> {
90        serde_json::from_value(value.clone())
91    }
92}
93
94pub fn get_pydantic_module<'py>(py: Python<'py>, module_name: &str) -> PyResult<Bound<'py, PyAny>> {
95    py.import("pydantic_ai")?.getattr(module_name)
96}
97
98/// Checks if an object is a subclass of a pydantic BaseModel. This is used when validating structured outputs
99/// # Arguments
100/// * `py` - The Python interpreter instance
101/// * `object` - The object to check
102/// # Returns
103/// A boolean indicating whether the object is a subclass of pydantic.BaseModel
104pub fn check_pydantic_model<'py>(
105    py: Python<'py>,
106    object: &Bound<'_, PyAny>,
107) -> Result<bool, TypeError> {
108    // check pydantic import. Return false if it fails
109    let pydantic = match py.import("pydantic").map_err(|e| {
110        error!("Failed to import pydantic: {}", e);
111        false
112    }) {
113        Ok(pydantic) => pydantic,
114        Err(_) => return Ok(false),
115    };
116
117    // get builtin subclass
118    let is_subclass = py.import("builtins")?.getattr("issubclass")?;
119
120    // Need to check if provided object is a basemodel
121    let basemodel = pydantic.getattr("BaseModel")?;
122    let matched = is_subclass.call1((object, basemodel))?.extract::<bool>()?;
123
124    Ok(matched)
125}
126
127/// Generate a JSON schema from a pydantic BaseModel object.
128/// # Arguments
129/// * `object` - The pydantic BaseModel object to generate the schema from.
130/// # Returns
131/// A JSON schema as a serde_json::Value.
132fn get_json_schema_from_basemodel(object: &Bound<'_, PyAny>) -> Result<Value, TypeError> {
133    // call staticmethod .model_json_schema()
134    let schema = object.getattr("model_json_schema")?.call1(())?;
135
136    let mut schema: Value = depythonize(&schema)?;
137
138    // ensure schema as additionalProperties set to false
139    if let Some(additional_properties) = schema.get_mut("additionalProperties") {
140        *additional_properties = serde_json::json!(false);
141    } else {
142        schema
143            .as_object_mut()
144            .unwrap()
145            .insert("additionalProperties".to_string(), serde_json::json!(false));
146    }
147
148    Ok(schema)
149}
150
151fn parse_pydantic_model<'py>(
152    py: Python<'py>,
153    object: &Bound<'_, PyAny>,
154) -> Result<Option<Value>, TypeError> {
155    let is_subclass = check_pydantic_model(py, object)?;
156    if is_subclass {
157        Ok(Some(get_json_schema_from_basemodel(object)?))
158    } else {
159        Ok(None)
160    }
161}
162
163pub fn check_response_type(object: &Bound<'_, PyAny>) -> Result<Option<ResponseType>, TypeError> {
164    // try calling staticmethod response_type()
165    let response_type = match object.getattr("response_type") {
166        Ok(method) => {
167            if method.is_callable() {
168                let response_type: ResponseType = method.call0()?.extract()?;
169                Some(response_type)
170            } else {
171                None
172            }
173        }
174        Err(_) => None,
175    };
176
177    Ok(response_type)
178}
179
180fn get_json_schema_from_response_type(response_type: &ResponseType) -> Result<Value, TypeError> {
181    match response_type {
182        ResponseType::Score => Ok(Score::get_structured_output_schema()),
183        _ => {
184            // If the response type is not recognized, return None
185            Err(TypeError::Error(format!(
186                "Unsupported response type: {response_type}"
187            )))
188        }
189    }
190}
191
192pub fn parse_response_to_json<'py>(
193    py: Python<'py>,
194    object: &Bound<'_, PyAny>,
195) -> Result<(ResponseType, Option<Value>), TypeError> {
196    // check if object is a pydantic model
197    let is_pydantic_model = check_pydantic_model(py, object)?;
198    if is_pydantic_model {
199        return Ok((ResponseType::Pydantic, parse_pydantic_model(py, object)?));
200    }
201
202    // check if object has response_type method
203    let response_type = check_response_type(object)?;
204    if let Some(response_type) = response_type {
205        return Ok((
206            response_type.clone(),
207            Some(get_json_schema_from_response_type(&response_type)?),
208        ));
209    }
210
211    Ok((ResponseType::Null, None))
212}
213
214#[pyclass]
215#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
216#[serde(deny_unknown_fields)] // ensure strict validation
217pub struct Score {
218    #[pyo3(get)]
219    #[schemars(range(min = 1, max = 5))]
220    pub score: i64,
221
222    #[pyo3(get)]
223    pub reason: String,
224}
225#[pymethods]
226impl Score {
227    #[staticmethod]
228    pub fn response_type() -> ResponseType {
229        ResponseType::Score
230    }
231
232    #[staticmethod]
233    pub fn model_validate_json(json_string: String) -> Result<Score, TypeError> {
234        Ok(serde_json::from_str(&json_string)?)
235    }
236
237    #[staticmethod]
238    pub fn model_json_schema<'py>(py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
239        let schema = Score::get_structured_output_schema();
240        Ok(pythonize(py, &schema)?)
241    }
242
243    pub fn __str__(&self) -> String {
244        PyHelperFuncs::__str__(self)
245    }
246}
247
248impl StructuredOutput for Score {}
249
250#[pyclass]
251#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
252pub enum ResponseType {
253    Score,
254    Pydantic,
255    Null, // This is used when no response type is specified
256}
257
258impl Display for ResponseType {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        match self {
261            ResponseType::Score => write!(f, "Score"),
262            ResponseType::Pydantic => write!(f, "Pydantic"),
263            ResponseType::Null => write!(f, "Null"),
264        }
265    }
266}
267
268// add conversion logic based on message conversion trait
269
270#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
271#[serde(untagged)]
272pub enum MessageNum {
273    OpenAIMessageV1(OpenAIChatMessage),
274    AnthropicMessageV1(AnthropicMessage),
275    GeminiContentV1(GeminiContent),
276
277    // this is a special case for Anthropic system messages
278    AnthropicSystemMessageV1(TextBlockParam),
279}
280
281impl MessageNum {
282    /// Checks if the message type matches the given provider
283    fn matches_provider(&self, provider: &Provider) -> bool {
284        matches!(
285            (self, provider),
286            (MessageNum::OpenAIMessageV1(_), Provider::OpenAI)
287                | (MessageNum::AnthropicMessageV1(_), Provider::Anthropic)
288                | (MessageNum::AnthropicSystemMessageV1(_), Provider::Anthropic)
289                | (MessageNum::GeminiContentV1(_), Provider::Google)
290                | (MessageNum::GeminiContentV1(_), Provider::Vertex)
291                | (MessageNum::GeminiContentV1(_), Provider::Gemini)
292        )
293    }
294
295    /// Converts the message to an openai message
296    /// This is only done for anthropic and gemini messages
297    /// openai message will return a failure if called on an openai message
298    /// Control flow should ensure this is only called on non-openai messages
299    fn to_openai_message(&self) -> Result<MessageNum, TypeError> {
300        match self {
301            MessageNum::AnthropicMessageV1(msg) => {
302                Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
303            }
304            MessageNum::GeminiContentV1(msg) => {
305                Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
306            }
307            _ => Err(TypeError::CantConvertSelf),
308        }
309    }
310
311    /// Converts to Anthropic message format
312    fn to_anthropic_message(&self) -> Result<MessageNum, TypeError> {
313        match self {
314            MessageNum::OpenAIMessageV1(msg) => {
315                Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
316            }
317            MessageNum::GeminiContentV1(msg) => {
318                Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
319            }
320            _ => Err(TypeError::CantConvertSelf),
321        }
322    }
323
324    /// Converts to Google Gemini message format
325    fn to_google_message(&self) -> Result<MessageNum, TypeError> {
326        match self {
327            MessageNum::OpenAIMessageV1(msg) => {
328                Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
329            }
330            MessageNum::AnthropicMessageV1(msg) => {
331                Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
332            }
333            _ => Err(TypeError::CantConvertSelf),
334        }
335    }
336
337    fn convert_message_to_provider_type(
338        &self,
339        provider: &Provider,
340    ) -> Result<MessageNum, TypeError> {
341        match provider {
342            Provider::OpenAI => self.to_openai_message(),
343            Provider::Anthropic => self.to_anthropic_message(),
344            Provider::Google => self.to_google_message(),
345            Provider::Vertex => self.to_google_message(),
346            Provider::Gemini => self.to_google_message(),
347            _ => Err(TypeError::UnsupportedProviderError),
348        }
349    }
350
351    pub fn convert_message(&mut self, provider: &Provider) -> Result<(), TypeError> {
352        // if message already matches provider, return Ok
353        if self.matches_provider(provider) {
354            return Ok(());
355        }
356        let converted = self.convert_message_to_provider_type(provider)?;
357        *self = converted;
358        Ok(())
359    }
360
361    pub fn anthropic_message_to_system_message(&mut self) -> Result<(), TypeError> {
362        match self {
363            MessageNum::AnthropicMessageV1(msg) => {
364                let text_param = msg.to_text_block_param()?;
365                *self = MessageNum::AnthropicSystemMessageV1(text_param);
366                Ok(())
367            }
368            _ => Err(TypeError::Error(
369                "Cannot convert non-AnthropicMessageV1 to system message".to_string(),
370            )),
371        }
372    }
373    pub fn role(&self) -> &str {
374        match self {
375            MessageNum::OpenAIMessageV1(msg) => &msg.role,
376            MessageNum::AnthropicMessageV1(msg) => &msg.role,
377            MessageNum::GeminiContentV1(msg) => &msg.role,
378            _ => "system",
379        }
380    }
381    pub fn bind(&self, name: &str, value: &str) -> Result<Self, TypeError> {
382        match self {
383            MessageNum::OpenAIMessageV1(msg) => {
384                let bound_msg = msg.bind(name, value)?;
385                Ok(MessageNum::OpenAIMessageV1(bound_msg))
386            }
387            MessageNum::AnthropicMessageV1(msg) => {
388                let bound_msg = msg.bind(name, value)?;
389                Ok(MessageNum::AnthropicMessageV1(bound_msg))
390            }
391            MessageNum::GeminiContentV1(msg) => {
392                let bound_msg = msg.bind(name, value)?;
393                Ok(MessageNum::GeminiContentV1(bound_msg))
394            }
395            _ => Ok(self.clone()),
396        }
397    }
398    pub fn bind_mut(&mut self, name: &str, value: &str) -> Result<(), TypeError> {
399        match self {
400            MessageNum::OpenAIMessageV1(msg) => msg.bind_mut(name, value),
401            MessageNum::AnthropicMessageV1(msg) => msg.bind_mut(name, value),
402            MessageNum::GeminiContentV1(msg) => msg.bind_mut(name, value),
403            _ => Ok(()),
404        }
405    }
406
407    pub(crate) fn extract_variables(&self) -> Vec<String> {
408        match self {
409            MessageNum::OpenAIMessageV1(msg) => msg.extract_variables(),
410            MessageNum::AnthropicMessageV1(msg) => msg.extract_variables(),
411            MessageNum::GeminiContentV1(msg) => msg.extract_variables(),
412            _ => vec![],
413        }
414    }
415
416    pub fn to_bound_py_object<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
417        match self {
418            MessageNum::OpenAIMessageV1(msg) => {
419                let bound_msg = msg.clone().into_bound_py_any(py)?;
420                Ok(bound_msg)
421            }
422            MessageNum::AnthropicMessageV1(msg) => {
423                let bound_msg = msg.clone().into_bound_py_any(py)?;
424                Ok(bound_msg)
425            }
426            MessageNum::GeminiContentV1(msg) => {
427                let bound_msg = msg.clone().into_bound_py_any(py)?;
428                Ok(bound_msg)
429            }
430            MessageNum::AnthropicSystemMessageV1(msg) => {
431                // Convert to AnthropicMessage first
432                let anthropic_msg = AnthropicMessage::from_text(msg.text.clone(), self.role())?;
433                let bound_msg = anthropic_msg.into_bound_py_any(py)?;
434                Ok(bound_msg)
435            }
436        }
437    }
438
439    pub fn to_bound_openai_message<'py>(
440        &self,
441        py: Python<'py>,
442    ) -> Result<Bound<'py, OpenAIChatMessage>, TypeError> {
443        match self {
444            MessageNum::OpenAIMessageV1(msg) => {
445                let py_obj = Py::new(py, msg.clone())?;
446                let bound = py_obj.bind(py);
447                Ok(bound.clone())
448            }
449            _ => Err(TypeError::CantConvertSelf),
450        }
451    }
452
453    pub fn to_bound_gemini_message<'py>(
454        &self,
455        py: Python<'py>,
456    ) -> Result<Bound<'py, GeminiContent>, TypeError> {
457        match self {
458            MessageNum::GeminiContentV1(msg) => {
459                let py_obj = Py::new(py, msg.clone())?;
460                let bound = py_obj.bind(py);
461                Ok(bound.clone())
462            }
463            _ => Err(TypeError::CantConvertSelf),
464        }
465    }
466
467    pub fn to_bound_anthropic_message<'py>(
468        &self,
469        py: Python<'py>,
470    ) -> Result<Bound<'py, AnthropicMessage>, TypeError> {
471        match self {
472            MessageNum::AnthropicMessageV1(msg) => {
473                let py_obj = Py::new(py, msg.clone())?;
474                let bound = py_obj.bind(py);
475                Ok(bound.clone())
476            }
477            _ => Err(TypeError::CantConvertSelf),
478        }
479    }
480
481    pub fn is_system_message(&self) -> bool {
482        match self {
483            MessageNum::OpenAIMessageV1(msg) => {
484                msg.role == Role::Developer.to_string() || msg.role == Role::System.to_string()
485            }
486            MessageNum::AnthropicMessageV1(msg) => msg.role == Role::System.to_string(),
487            MessageNum::GeminiContentV1(msg) => msg.role == Role::Model.to_string(),
488            MessageNum::AnthropicSystemMessageV1(_) => true,
489        }
490    }
491
492    pub fn is_user_message(&self) -> bool {
493        match self {
494            MessageNum::OpenAIMessageV1(msg) => msg.role == Role::User.to_string(),
495            MessageNum::AnthropicMessageV1(msg) => msg.role == Role::User.to_string(),
496            MessageNum::GeminiContentV1(msg) => msg.role == Role::User.to_string(),
497            MessageNum::AnthropicSystemMessageV1(_) => false,
498        }
499    }
500}
501
502#[derive(Debug, Clone)]
503pub enum ResponseContent {
504    OpenAI(Choice),
505    Google(Candidate),
506    Anthropic(ResponseContentBlock),
507    PredictResponse(PredictResponse),
508}
509
510#[pyclass]
511pub struct OpenAIMessageList {
512    pub messages: Vec<OpenAIChatMessage>,
513}
514
515#[pymethods]
516impl OpenAIMessageList {
517    fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<OpenAIMessageIterator>> {
518        let iter = OpenAIMessageIterator {
519            inner: slf.messages.clone().into_iter(),
520        };
521        Py::new(slf.py(), iter)
522    }
523
524    pub fn __len__(&self) -> usize {
525        self.messages.len()
526    }
527
528    pub fn __getitem__(&self, index: isize) -> Result<OpenAIChatMessage, TypeError> {
529        let len = self.messages.len() as isize;
530        let normalized_index = if index < 0 { len + index } else { index };
531
532        if normalized_index < 0 || normalized_index >= len {
533            return Err(TypeError::Error(format!(
534                "Index {} out of range for list of length {}",
535                index, len
536            )));
537        }
538
539        Ok(self.messages[normalized_index as usize].clone())
540    }
541
542    pub fn __str__(&self) -> String {
543        PyHelperFuncs::__str__(&self.messages)
544    }
545
546    pub fn __repr__(&self) -> String {
547        self.__str__()
548    }
549}
550
551#[pyclass]
552pub struct OpenAIMessageIterator {
553    inner: std::vec::IntoIter<OpenAIChatMessage>,
554}
555
556#[pymethods]
557impl OpenAIMessageIterator {
558    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
559        slf
560    }
561
562    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<OpenAIChatMessage> {
563        slf.inner.next()
564    }
565}
566
567#[pyclass]
568pub struct AnthropicMessageList {
569    pub messages: Vec<AnthropicMessage>,
570}
571
572#[pymethods]
573impl AnthropicMessageList {
574    fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<AnthropicMessageIterator>> {
575        let iter = AnthropicMessageIterator {
576            inner: slf.messages.clone().into_iter(),
577        };
578        Py::new(slf.py(), iter)
579    }
580
581    pub fn __len__(&self) -> usize {
582        self.messages.len()
583    }
584
585    pub fn __getitem__(&self, index: isize) -> Result<AnthropicMessage, TypeError> {
586        let len = self.messages.len() as isize;
587        let normalized_index = if index < 0 { len + index } else { index };
588
589        if normalized_index < 0 || normalized_index >= len {
590            return Err(TypeError::Error(format!(
591                "Index {} out of range for list of length {}",
592                index, len
593            )));
594        }
595
596        Ok(self.messages[normalized_index as usize].clone())
597    }
598
599    pub fn __str__(&self) -> String {
600        PyHelperFuncs::__str__(&self.messages)
601    }
602
603    pub fn __repr__(&self) -> String {
604        self.__str__()
605    }
606}
607
608#[pyclass]
609pub struct AnthropicMessageIterator {
610    inner: std::vec::IntoIter<AnthropicMessage>,
611}
612
613#[pymethods]
614impl AnthropicMessageIterator {
615    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
616        slf
617    }
618
619    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<AnthropicMessage> {
620        slf.inner.next()
621    }
622}
623
624#[pyclass]
625pub struct GeminiContentList {
626    pub messages: Vec<GeminiContent>,
627}
628
629#[pymethods]
630impl GeminiContentList {
631    fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<GeminiContentIterator>> {
632        let iter = GeminiContentIterator {
633            inner: slf.messages.clone().into_iter(),
634        };
635        Py::new(slf.py(), iter)
636    }
637
638    pub fn __len__(&self) -> usize {
639        self.messages.len()
640    }
641
642    pub fn __getitem__(&self, index: isize) -> Result<GeminiContent, TypeError> {
643        let len = self.messages.len() as isize;
644        let normalized_index = if index < 0 { len + index } else { index };
645
646        if normalized_index < 0 || normalized_index >= len {
647            return Err(TypeError::Error(format!(
648                "Index {} out of range for list of length {}",
649                index, len
650            )));
651        }
652
653        Ok(self.messages[normalized_index as usize].clone())
654    }
655
656    pub fn __str__(&self) -> String {
657        PyHelperFuncs::__str__(&self.messages)
658    }
659
660    pub fn __repr__(&self) -> String {
661        self.__str__()
662    }
663}
664
665#[pyclass]
666pub struct GeminiContentIterator {
667    inner: std::vec::IntoIter<GeminiContent>,
668}
669
670#[pymethods]
671impl GeminiContentIterator {
672    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
673        slf
674    }
675
676    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<GeminiContent> {
677        slf.inner.next()
678    }
679}