Skip to main content

potato_type/prompt/
types.rs

1use crate::anthropic::v1::request::MessageParam as AnthropicMessage;
2use crate::anthropic::v1::request::TextBlockParam;
3use crate::anthropic::v1::response::ResponseContentBlock;
4use crate::google::v1::generate::Candidate;
5use crate::google::v1::generate::GeminiContent;
6use crate::google::PredictResponse;
7use crate::openai::v1::chat::request::ChatMessage as OpenAIChatMessage;
8use crate::openai::v1::Choice;
9use crate::traits::MessageConversion;
10use crate::traits::PromptMessageExt;
11use crate::Provider;
12use crate::{StructuredOutput, TypeError};
13use potato_util::PyHelperFuncs;
14use pyo3::types::PyAnyMethods;
15use pyo3::{prelude::*, IntoPyObjectExt};
16use pythonize::{depythonize, pythonize};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::fmt::Display;
21use tracing::error;
/// Message author role, exposed to Python as a class.
///
/// The lower-case text form of each variant (e.g. `"user"`) is provided by
/// `Role::as_str`, `Display` and `From<Role> for &str` in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[pyclass]
pub enum Role {
    /// Text form: `"user"`.
    User,
    /// Text form: `"assistant"`.
    Assistant,
    /// Text form: `"developer"`.
    Developer,
    /// Text form: `"tool"`.
    Tool,
    /// Text form: `"model"`.
    Model,
    /// Text form: `"system"`.
    System,
}
32
33#[pymethods]
34impl Role {
35    #[pyo3(name = "as_str")]
36    pub fn as_str_py(&self) -> &'static str {
37        self.as_str()
38    }
39}
40
41impl Role {
42    /// Returns the string representation of the role
43    pub const fn as_str(&self) -> &'static str {
44        match self {
45            Role::User => "user",
46            Role::Assistant => "assistant",
47            Role::Developer => "developer",
48            Role::Tool => "tool",
49            Role::Model => "model",
50            Role::System => "system",
51        }
52    }
53}
54
55impl Display for Role {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        match self {
58            Role::User => write!(f, "user"),
59            Role::Assistant => write!(f, "assistant"),
60            Role::Developer => write!(f, "developer"),
61            Role::Tool => write!(f, "tool"),
62            Role::Model => write!(f, "model"),
63            Role::System => write!(f, "system"),
64        }
65    }
66}
67
68impl From<Role> for &str {
69    fn from(role: Role) -> Self {
70        match role {
71            Role::User => "user",
72            Role::Assistant => "assistant",
73            Role::Developer => "developer",
74            Role::Tool => "tool",
75            Role::Model => "model",
76            Role::System => "system",
77        }
78    }
79}
80
81pub trait DeserializePromptValExt: for<'de> serde::Deserialize<'de> {
82    /// Validates and deserializes a JSON value into its struct type.
83    ///
84    /// # Arguments
85    /// * `value` - The JSON value to deserialize
86    ///
87    /// # Returns
88    /// * `Result<Self, serde_json::Error>` - The deserialized value or error
89    fn model_validate_json(value: &Value) -> Result<Self, serde_json::Error> {
90        serde_json::from_value(value.clone())
91    }
92}
93
94pub fn get_pydantic_module<'py>(py: Python<'py>, module_name: &str) -> PyResult<Bound<'py, PyAny>> {
95    py.import("pydantic_ai")?.getattr(module_name)
96}
97
98/// Checks if an object is a subclass of a pydantic BaseModel. This is used when validating structured outputs
99/// # Arguments
100/// * `py` - The Python interpreter instance
101/// * `object` - The object to check
102/// # Returns
103/// A boolean indicating whether the object is a subclass of pydantic.BaseModel
104pub fn check_pydantic_model<'py>(
105    py: Python<'py>,
106    object: &Bound<'_, PyAny>,
107) -> Result<bool, TypeError> {
108    // check pydantic import. Return false if it fails
109    let pydantic = match py.import("pydantic").map_err(|e| {
110        error!("Failed to import pydantic: {}", e);
111        false
112    }) {
113        Ok(pydantic) => pydantic,
114        Err(_) => return Ok(false),
115    };
116
117    // get builtin subclass
118    let is_subclass = py.import("builtins")?.getattr("issubclass")?;
119
120    // Need to check if provided object is a basemodel
121    let basemodel = pydantic.getattr("BaseModel")?;
122    let matched = is_subclass.call1((object, basemodel))?.extract::<bool>()?;
123
124    Ok(matched)
125}
126
127/// Generate a JSON schema from a pydantic BaseModel object.
128/// # Arguments
129/// * `object` - The pydantic BaseModel object to generate the schema from.
130/// # Returns
131/// A JSON schema as a serde_json::Value.
132fn get_json_schema_from_basemodel(object: &Bound<'_, PyAny>) -> Result<Value, TypeError> {
133    // call staticmethod .model_json_schema()
134    let schema = object.getattr("model_json_schema")?.call1(())?;
135
136    let mut schema: Value = depythonize(&schema)?;
137
138    // ensure schema as additionalProperties set to false
139    if let Some(additional_properties) = schema.get_mut("additionalProperties") {
140        *additional_properties = serde_json::json!(false);
141    } else {
142        schema
143            .as_object_mut()
144            .unwrap()
145            .insert("additionalProperties".to_string(), serde_json::json!(false));
146    }
147
148    Ok(schema)
149}
150
151fn parse_pydantic_model<'py>(
152    py: Python<'py>,
153    object: &Bound<'_, PyAny>,
154) -> Result<Option<Value>, TypeError> {
155    let is_subclass = check_pydantic_model(py, object)?;
156    if is_subclass {
157        Ok(Some(get_json_schema_from_basemodel(object)?))
158    } else {
159        Ok(None)
160    }
161}
162
163pub fn check_response_type(object: &Bound<'_, PyAny>) -> Result<Option<ResponseType>, TypeError> {
164    // try calling staticmethod response_type()
165    let response_type = match object.getattr("response_type") {
166        Ok(method) => {
167            if method.is_callable() {
168                let response_type: ResponseType = method.call0()?.extract()?;
169                Some(response_type)
170            } else {
171                None
172            }
173        }
174        Err(_) => None,
175    };
176
177    Ok(response_type)
178}
179
180fn get_json_schema_from_response_type(response_type: &ResponseType) -> Result<Value, TypeError> {
181    match response_type {
182        ResponseType::Score => Ok(Score::get_structured_output_schema()),
183        _ => {
184            // If the response type is not recognized, return None
185            Err(TypeError::Error(format!(
186                "Unsupported response type: {response_type}"
187            )))
188        }
189    }
190}
191
192pub fn parse_response_to_json<'py>(
193    py: Python<'py>,
194    object: &Bound<'_, PyAny>,
195) -> Result<(ResponseType, Option<Value>), TypeError> {
196    // check if object is a pydantic model
197    let is_pydantic_model = check_pydantic_model(py, object)?;
198    if is_pydantic_model {
199        return Ok((ResponseType::Pydantic, parse_pydantic_model(py, object)?));
200    }
201
202    // check if object has response_type method
203    let response_type = check_response_type(object)?;
204    if let Some(response_type) = response_type {
205        return Ok((
206            response_type.clone(),
207            Some(get_json_schema_from_response_type(&response_type)?),
208        ));
209    }
210
211    Ok((ResponseType::Null, None))
212}
213
// Structured-output payload: a numeric score plus a free-text reason.
// Plain `//` comments are used on purpose: this type derives `JsonSchema`,
// and schemars copies `///` doc comments into the generated schema.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)] // ensure strict validation
pub struct Score {
    // Schema range is 1..=5.
    // NOTE(review): the range is a schema annotation; nothing visible here
    // enforces it during deserialization — confirm whether that is needed.
    #[pyo3(get)]
    #[schemars(range(min = 1, max = 5))]
    pub score: i64,

    // Explanation accompanying the score; readable from Python.
    #[pyo3(get)]
    pub reason: String,
}
225#[pymethods]
226impl Score {
227    #[staticmethod]
228    pub fn response_type() -> ResponseType {
229        ResponseType::Score
230    }
231
232    #[staticmethod]
233    pub fn model_validate_json(json_string: String) -> Result<Score, TypeError> {
234        Ok(serde_json::from_str(&json_string)?)
235    }
236
237    #[staticmethod]
238    pub fn model_json_schema<'py>(py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
239        let schema = Score::get_structured_output_schema();
240        Ok(pythonize(py, &schema)?)
241    }
242
243    pub fn __str__(&self) -> String {
244        PyHelperFuncs::__str__(self)
245    }
246}
247
// Empty impl: `Score` takes whatever default methods `StructuredOutput` provides.
// Presumably this is where `Score::get_structured_output_schema` comes from — verify against the trait.
impl StructuredOutput for Score {}
249
/// Kind of structured response attached to a prompt; defaults to `Null`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub enum ResponseType {
    /// The built-in `Score` structure.
    Score,
    /// A user-supplied pydantic BaseModel subclass.
    Pydantic,
    #[default]
    Null, // This is used when no response type is specified
}
258
259impl Display for ResponseType {
260    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261        match self {
262            ResponseType::Score => write!(f, "Score"),
263            ResponseType::Pydantic => write!(f, "Pydantic"),
264            ResponseType::Null => write!(f, "Null"),
265        }
266    }
267}
268
269// add conversion logic based on message conversion trait
270
/// A single prompt message in one of the supported provider formats.
///
/// The enum is `#[serde(untagged)]`: no tag is written, and deserialization
/// tries the variants in declaration order and keeps the first that matches,
/// so the order of the variants below is significant.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageNum {
    /// OpenAI chat message.
    OpenAIMessageV1(OpenAIChatMessage),
    /// Anthropic message parameter.
    AnthropicMessageV1(AnthropicMessage),
    /// Gemini content.
    GeminiContentV1(GeminiContent),

    // this is a special case for Anthropic system messages
    AnthropicSystemMessageV1(TextBlockParam),

    // Raw JSON message - used internally in the agentic loop for provider-specific
    // messages that don't fit the typed variants (e.g., OpenAI tool-call assistant messages).
    // Must be LAST so serde tries typed variants first when deserializing.
    RawV1(serde_json::Value),
}
286
287impl MessageNum {
288    /// Checks if the message type matches the given provider
289    fn matches_provider(&self, provider: &Provider) -> bool {
290        matches!(
291            (self, provider),
292            (MessageNum::OpenAIMessageV1(_), Provider::OpenAI)
293                | (MessageNum::AnthropicMessageV1(_), Provider::Anthropic)
294                | (MessageNum::AnthropicSystemMessageV1(_), Provider::Anthropic)
295                | (MessageNum::GeminiContentV1(_), Provider::Google)
296                | (MessageNum::GeminiContentV1(_), Provider::Vertex)
297                | (MessageNum::GeminiContentV1(_), Provider::Gemini)
298        )
299    }
300
301    /// Converts the message to an openai message
302    /// This is only done for anthropic and gemini messages
303    /// openai message will return a failure if called on an openai message
304    /// Control flow should ensure this is only called on non-openai messages
305    fn to_openai_message(&self) -> Result<MessageNum, TypeError> {
306        match self {
307            MessageNum::AnthropicMessageV1(msg) => {
308                Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
309            }
310            MessageNum::GeminiContentV1(msg) => {
311                Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
312            }
313            _ => Err(TypeError::CantConvertSelf),
314        }
315    }
316
317    /// Converts to Anthropic message format
318    fn to_anthropic_message(&self) -> Result<MessageNum, TypeError> {
319        match self {
320            MessageNum::OpenAIMessageV1(msg) => {
321                Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
322            }
323            MessageNum::GeminiContentV1(msg) => {
324                Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
325            }
326            _ => Err(TypeError::CantConvertSelf),
327        }
328    }
329
330    /// Converts to Google Gemini message format
331    fn to_google_message(&self) -> Result<MessageNum, TypeError> {
332        match self {
333            MessageNum::OpenAIMessageV1(msg) => {
334                Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
335            }
336            MessageNum::AnthropicMessageV1(msg) => {
337                Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
338            }
339            _ => Err(TypeError::CantConvertSelf),
340        }
341    }
342
343    fn convert_message_to_provider_type(
344        &self,
345        provider: &Provider,
346    ) -> Result<MessageNum, TypeError> {
347        match provider {
348            Provider::OpenAI => self.to_openai_message(),
349            Provider::Anthropic => self.to_anthropic_message(),
350            Provider::Google => self.to_google_message(),
351            Provider::Vertex => self.to_google_message(),
352            Provider::Gemini => self.to_google_message(),
353            _ => Err(TypeError::UnsupportedProviderError),
354        }
355    }
356
357    pub fn convert_message(&mut self, provider: &Provider) -> Result<(), TypeError> {
358        // RawV1 is provider-specific raw JSON — skip conversion
359        if matches!(self, MessageNum::RawV1(_)) {
360            return Ok(());
361        }
362        // if message already matches provider, return Ok
363        if self.matches_provider(provider) {
364            return Ok(());
365        }
366        let converted = self.convert_message_to_provider_type(provider)?;
367        *self = converted;
368        Ok(())
369    }
370
371    pub fn anthropic_message_to_system_message(&mut self) -> Result<(), TypeError> {
372        match self {
373            MessageNum::AnthropicMessageV1(msg) => {
374                let text_param = msg.to_text_block_param()?;
375                *self = MessageNum::AnthropicSystemMessageV1(text_param);
376                Ok(())
377            }
378            _ => Err(TypeError::Error(
379                "Cannot convert non-AnthropicMessageV1 to system message".to_string(),
380            )),
381        }
382    }
383    pub fn role(&self) -> &str {
384        match self {
385            MessageNum::OpenAIMessageV1(msg) => &msg.role,
386            MessageNum::AnthropicMessageV1(msg) => &msg.role,
387            MessageNum::GeminiContentV1(msg) => &msg.role,
388            _ => "system",
389        }
390    }
391    pub fn bind(&self, name: &str, value: &str) -> Result<Self, TypeError> {
392        match self {
393            MessageNum::OpenAIMessageV1(msg) => {
394                let bound_msg = msg.bind(name, value)?;
395                Ok(MessageNum::OpenAIMessageV1(bound_msg))
396            }
397            MessageNum::AnthropicMessageV1(msg) => {
398                let bound_msg = msg.bind(name, value)?;
399                Ok(MessageNum::AnthropicMessageV1(bound_msg))
400            }
401            MessageNum::GeminiContentV1(msg) => {
402                let bound_msg = msg.bind(name, value)?;
403                Ok(MessageNum::GeminiContentV1(bound_msg))
404            }
405            _ => Ok(self.clone()),
406        }
407    }
408    pub fn bind_mut(&mut self, name: &str, value: &str) -> Result<(), TypeError> {
409        match self {
410            MessageNum::OpenAIMessageV1(msg) => msg.bind_mut(name, value),
411            MessageNum::AnthropicMessageV1(msg) => msg.bind_mut(name, value),
412            MessageNum::GeminiContentV1(msg) => msg.bind_mut(name, value),
413            _ => Ok(()),
414        }
415    }
416
417    pub(crate) fn extract_variables(&self) -> Vec<String> {
418        match self {
419            MessageNum::OpenAIMessageV1(msg) => msg.extract_variables(),
420            MessageNum::AnthropicMessageV1(msg) => msg.extract_variables(),
421            MessageNum::GeminiContentV1(msg) => msg.extract_variables(),
422            _ => vec![],
423        }
424    }
425
426    pub fn to_bound_py_object<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
427        match self {
428            MessageNum::OpenAIMessageV1(msg) => {
429                let bound_msg = msg.clone().into_bound_py_any(py)?;
430                Ok(bound_msg)
431            }
432            MessageNum::AnthropicMessageV1(msg) => {
433                let bound_msg = msg.clone().into_bound_py_any(py)?;
434                Ok(bound_msg)
435            }
436            MessageNum::GeminiContentV1(msg) => {
437                let bound_msg = msg.clone().into_bound_py_any(py)?;
438                Ok(bound_msg)
439            }
440            MessageNum::AnthropicSystemMessageV1(msg) => {
441                // Convert to AnthropicMessage first
442                let anthropic_msg = AnthropicMessage::from_text(msg.text.clone(), self.role())?;
443                let bound_msg = anthropic_msg.into_bound_py_any(py)?;
444                Ok(bound_msg)
445            }
446            MessageNum::RawV1(v) => Ok(pythonize(py, v)?),
447        }
448    }
449
450    pub fn to_bound_openai_message<'py>(
451        &self,
452        py: Python<'py>,
453    ) -> Result<Bound<'py, OpenAIChatMessage>, TypeError> {
454        match self {
455            MessageNum::OpenAIMessageV1(msg) => {
456                let py_obj = Py::new(py, msg.clone())?;
457                let bound = py_obj.bind(py);
458                Ok(bound.clone())
459            }
460            _ => Err(TypeError::CantConvertSelf),
461        }
462    }
463
464    pub fn to_bound_gemini_message<'py>(
465        &self,
466        py: Python<'py>,
467    ) -> Result<Bound<'py, GeminiContent>, TypeError> {
468        match self {
469            MessageNum::GeminiContentV1(msg) => {
470                let py_obj = Py::new(py, msg.clone())?;
471                let bound = py_obj.bind(py);
472                Ok(bound.clone())
473            }
474            _ => Err(TypeError::CantConvertSelf),
475        }
476    }
477
478    pub fn to_bound_anthropic_message<'py>(
479        &self,
480        py: Python<'py>,
481    ) -> Result<Bound<'py, AnthropicMessage>, TypeError> {
482        match self {
483            MessageNum::AnthropicMessageV1(msg) => {
484                let py_obj = Py::new(py, msg.clone())?;
485                let bound = py_obj.bind(py);
486                Ok(bound.clone())
487            }
488            _ => Err(TypeError::CantConvertSelf),
489        }
490    }
491
492    pub fn is_system_message(&self) -> bool {
493        match self {
494            MessageNum::OpenAIMessageV1(msg) => {
495                msg.role == Role::Developer.to_string() || msg.role == Role::System.to_string()
496            }
497            MessageNum::AnthropicMessageV1(msg) => msg.role == Role::System.to_string(),
498            MessageNum::GeminiContentV1(msg) => msg.role == Role::Model.to_string(),
499            MessageNum::AnthropicSystemMessageV1(_) => true,
500            MessageNum::RawV1(v) => v
501                .get("role")
502                .and_then(|r| r.as_str())
503                .map(|r| r == "system" || r == "developer")
504                .unwrap_or(false),
505        }
506    }
507
508    pub fn is_user_message(&self) -> bool {
509        match self {
510            MessageNum::OpenAIMessageV1(msg) => msg.role == Role::User.to_string(),
511            MessageNum::AnthropicMessageV1(msg) => msg.role == Role::User.to_string(),
512            MessageNum::GeminiContentV1(msg) => msg.role == Role::User.to_string(),
513            MessageNum::AnthropicSystemMessageV1(_) => false,
514            MessageNum::RawV1(v) => v
515                .get("role")
516                .and_then(|r| r.as_str())
517                .map(|r| r == "user")
518                .unwrap_or(false),
519        }
520    }
521}
522
/// One piece of response content, wrapped by the type it came in as.
#[derive(Debug, Clone)]
pub enum ResponseContent {
    /// Wraps an OpenAI `Choice`.
    OpenAI(Choice),
    /// Wraps a Google `Candidate`.
    Google(Candidate),
    /// Wraps an Anthropic `ResponseContentBlock`.
    Anthropic(ResponseContentBlock),
    /// Wraps a Google `PredictResponse`.
    PredictResponse(PredictResponse),
}
530
/// Python-visible, read-only sequence of OpenAI chat messages
/// (supports iteration, `len()` and indexing).
#[pyclass]
pub struct OpenAIMessageList {
    /// Backing storage; cloned whenever elements are handed to Python.
    pub messages: Vec<OpenAIChatMessage>,
}
535
536#[pymethods]
537impl OpenAIMessageList {
538    fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<OpenAIMessageIterator>> {
539        let iter = OpenAIMessageIterator {
540            inner: slf.messages.clone().into_iter(),
541        };
542        Py::new(slf.py(), iter)
543    }
544
545    pub fn __len__(&self) -> usize {
546        self.messages.len()
547    }
548
549    pub fn __getitem__(&self, index: isize) -> Result<OpenAIChatMessage, TypeError> {
550        let len = self.messages.len() as isize;
551        let normalized_index = if index < 0 { len + index } else { index };
552
553        if normalized_index < 0 || normalized_index >= len {
554            return Err(TypeError::Error(format!(
555                "Index {} out of range for list of length {}",
556                index, len
557            )));
558        }
559
560        Ok(self.messages[normalized_index as usize].clone())
561    }
562
563    pub fn __str__(&self) -> String {
564        PyHelperFuncs::__str__(&self.messages)
565    }
566
567    pub fn __repr__(&self) -> String {
568        self.__str__()
569    }
570}
571
/// Python iterator returned by `OpenAIMessageList.__iter__`.
#[pyclass]
pub struct OpenAIMessageIterator {
    /// Owning iterator over the cloned messages.
    inner: std::vec::IntoIter<OpenAIChatMessage>,
}
576
577#[pymethods]
578impl OpenAIMessageIterator {
579    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
580        slf
581    }
582
583    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<OpenAIChatMessage> {
584        slf.inner.next()
585    }
586}
587
/// Python-visible, read-only sequence of Anthropic messages
/// (supports iteration, `len()` and indexing).
#[pyclass]
pub struct AnthropicMessageList {
    /// Backing storage; cloned whenever elements are handed to Python.
    pub messages: Vec<AnthropicMessage>,
}
592
593#[pymethods]
594impl AnthropicMessageList {
595    fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<AnthropicMessageIterator>> {
596        let iter = AnthropicMessageIterator {
597            inner: slf.messages.clone().into_iter(),
598        };
599        Py::new(slf.py(), iter)
600    }
601
602    pub fn __len__(&self) -> usize {
603        self.messages.len()
604    }
605
606    pub fn __getitem__(&self, index: isize) -> Result<AnthropicMessage, TypeError> {
607        let len = self.messages.len() as isize;
608        let normalized_index = if index < 0 { len + index } else { index };
609
610        if normalized_index < 0 || normalized_index >= len {
611            return Err(TypeError::Error(format!(
612                "Index {} out of range for list of length {}",
613                index, len
614            )));
615        }
616
617        Ok(self.messages[normalized_index as usize].clone())
618    }
619
620    pub fn __str__(&self) -> String {
621        PyHelperFuncs::__str__(&self.messages)
622    }
623
624    pub fn __repr__(&self) -> String {
625        self.__str__()
626    }
627}
628
/// Python iterator returned by `AnthropicMessageList.__iter__`.
#[pyclass]
pub struct AnthropicMessageIterator {
    /// Owning iterator over the cloned messages.
    inner: std::vec::IntoIter<AnthropicMessage>,
}
633
634#[pymethods]
635impl AnthropicMessageIterator {
636    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
637        slf
638    }
639
640    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<AnthropicMessage> {
641        slf.inner.next()
642    }
643}
644
/// Python-visible, read-only sequence of Gemini contents
/// (supports iteration, `len()` and indexing).
#[pyclass]
pub struct GeminiContentList {
    /// Backing storage; cloned whenever elements are handed to Python.
    pub messages: Vec<GeminiContent>,
}
649
650#[pymethods]
651impl GeminiContentList {
652    fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<GeminiContentIterator>> {
653        let iter = GeminiContentIterator {
654            inner: slf.messages.clone().into_iter(),
655        };
656        Py::new(slf.py(), iter)
657    }
658
659    pub fn __len__(&self) -> usize {
660        self.messages.len()
661    }
662
663    pub fn __getitem__(&self, index: isize) -> Result<GeminiContent, TypeError> {
664        let len = self.messages.len() as isize;
665        let normalized_index = if index < 0 { len + index } else { index };
666
667        if normalized_index < 0 || normalized_index >= len {
668            return Err(TypeError::Error(format!(
669                "Index {} out of range for list of length {}",
670                index, len
671            )));
672        }
673
674        Ok(self.messages[normalized_index as usize].clone())
675    }
676
677    pub fn __str__(&self) -> String {
678        PyHelperFuncs::__str__(&self.messages)
679    }
680
681    pub fn __repr__(&self) -> String {
682        self.__str__()
683    }
684}
685
/// Python iterator returned by `GeminiContentList.__iter__`.
#[pyclass]
pub struct GeminiContentIterator {
    /// Owning iterator over the cloned contents.
    inner: std::vec::IntoIter<GeminiContent>,
}
690
691#[pymethods]
692impl GeminiContentIterator {
693    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
694        slf
695    }
696
697    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<GeminiContent> {
698        slf.inner.next()
699    }
700}