Skip to main content

potato_type/prompt/
types.rs

1use crate::anthropic::v1::request::MessageParam as AnthropicMessage;
2use crate::anthropic::v1::request::TextBlockParam;
3use crate::anthropic::v1::response::ResponseContentBlock;
4use crate::google::v1::generate::Candidate;
5use crate::google::v1::generate::GeminiContent;
6use crate::google::PredictResponse;
7use crate::openai::v1::chat::request::ChatMessage as OpenAIChatMessage;
8use crate::openai::v1::Choice;
9use crate::traits::MessageConversion;
10use crate::traits::PromptMessageExt;
11use crate::Provider;
12use crate::{StructuredOutput, TypeError};
13use potato_util::PyHelperFuncs;
14use pyo3::types::PyAnyMethods;
15use pyo3::{prelude::*, IntoPyObjectExt};
16use pythonize::{depythonize, pythonize};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::fmt::Display;
21use tracing::error;
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23#[pyclass]
24pub enum Role {
25    User,
26    Assistant,
27    Developer,
28    Tool,
29    Model,
30    System,
31}
32
33#[pymethods]
34impl Role {
35    #[pyo3(name = "as_str")]
36    pub fn as_str_py(&self) -> &'static str {
37        self.as_str()
38    }
39}
40
41impl Role {
42    /// Returns the string representation of the role
43    pub const fn as_str(&self) -> &'static str {
44        match self {
45            Role::User => "user",
46            Role::Assistant => "assistant",
47            Role::Developer => "developer",
48            Role::Tool => "tool",
49            Role::Model => "model",
50            Role::System => "system",
51        }
52    }
53}
54
55impl Display for Role {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        match self {
58            Role::User => write!(f, "user"),
59            Role::Assistant => write!(f, "assistant"),
60            Role::Developer => write!(f, "developer"),
61            Role::Tool => write!(f, "tool"),
62            Role::Model => write!(f, "model"),
63            Role::System => write!(f, "system"),
64        }
65    }
66}
67
68impl From<Role> for &str {
69    fn from(role: Role) -> Self {
70        match role {
71            Role::User => "user",
72            Role::Assistant => "assistant",
73            Role::Developer => "developer",
74            Role::Tool => "tool",
75            Role::Model => "model",
76            Role::System => "system",
77        }
78    }
79}
80
81pub trait DeserializePromptValExt: for<'de> serde::Deserialize<'de> {
82    /// Validates and deserializes a JSON value into its struct type.
83    ///
84    /// # Arguments
85    /// * `value` - The JSON value to deserialize
86    ///
87    /// # Returns
88    /// * `Result<Self, serde_json::Error>` - The deserialized value or error
89    fn model_validate_json(value: &Value) -> Result<Self, serde_json::Error> {
90        serde_json::from_value(value.clone())
91    }
92}
93
94pub fn get_pydantic_module<'py>(py: Python<'py>, module_name: &str) -> PyResult<Bound<'py, PyAny>> {
95    py.import("pydantic_ai")?.getattr(module_name)
96}
97
98/// Checks if an object is a subclass of a pydantic BaseModel. This is used when validating structured outputs
99/// # Arguments
100/// * `py` - The Python interpreter instance
101/// * `object` - The object to check
102/// # Returns
103/// A boolean indicating whether the object is a subclass of pydantic.BaseModel
104pub fn check_pydantic_model<'py>(
105    py: Python<'py>,
106    object: &Bound<'_, PyAny>,
107) -> Result<bool, TypeError> {
108    // check pydantic import. Return false if it fails
109    let pydantic = match py.import("pydantic").map_err(|e| {
110        error!("Failed to import pydantic: {}", e);
111        false
112    }) {
113        Ok(pydantic) => pydantic,
114        Err(_) => return Ok(false),
115    };
116
117    // get builtin subclass
118    let is_subclass = py.import("builtins")?.getattr("issubclass")?;
119
120    // Need to check if provided object is a basemodel
121    let basemodel = pydantic.getattr("BaseModel")?;
122    let matched = is_subclass.call1((object, basemodel))?.extract::<bool>()?;
123
124    Ok(matched)
125}
126
127/// Generate a JSON schema from a pydantic BaseModel object.
128/// # Arguments
129/// * `object` - The pydantic BaseModel object to generate the schema from.
130/// # Returns
131/// A JSON schema as a serde_json::Value.
132fn get_json_schema_from_basemodel(object: &Bound<'_, PyAny>) -> Result<Value, TypeError> {
133    // call staticmethod .model_json_schema()
134    let schema = object.getattr("model_json_schema")?.call1(())?;
135
136    let mut schema: Value = depythonize(&schema)?;
137
138    // ensure schema as additionalProperties set to false
139    if let Some(additional_properties) = schema.get_mut("additionalProperties") {
140        *additional_properties = serde_json::json!(false);
141    } else {
142        schema
143            .as_object_mut()
144            .unwrap()
145            .insert("additionalProperties".to_string(), serde_json::json!(false));
146    }
147
148    Ok(schema)
149}
150
151fn parse_pydantic_model<'py>(
152    py: Python<'py>,
153    object: &Bound<'_, PyAny>,
154) -> Result<Option<Value>, TypeError> {
155    let is_subclass = check_pydantic_model(py, object)?;
156    if is_subclass {
157        Ok(Some(get_json_schema_from_basemodel(object)?))
158    } else {
159        Ok(None)
160    }
161}
162
163pub fn check_response_type(object: &Bound<'_, PyAny>) -> Result<Option<ResponseType>, TypeError> {
164    // try calling staticmethod response_type()
165    let response_type = match object.getattr("response_type") {
166        Ok(method) => {
167            if method.is_callable() {
168                let response_type: ResponseType = method.call0()?.extract()?;
169                Some(response_type)
170            } else {
171                None
172            }
173        }
174        Err(_) => None,
175    };
176
177    Ok(response_type)
178}
179
180fn get_json_schema_from_response_type(response_type: &ResponseType) -> Result<Value, TypeError> {
181    match response_type {
182        ResponseType::Score => Ok(Score::get_structured_output_schema()),
183        _ => {
184            // If the response type is not recognized, return None
185            Err(TypeError::Error(format!(
186                "Unsupported response type: {response_type}"
187            )))
188        }
189    }
190}
191
192pub fn parse_response_to_json<'py>(
193    py: Python<'py>,
194    object: &Bound<'_, PyAny>,
195) -> Result<(ResponseType, Option<Value>), TypeError> {
196    // check if object is a pydantic model
197    let is_pydantic_model = check_pydantic_model(py, object)?;
198    if is_pydantic_model {
199        return Ok((ResponseType::Pydantic, parse_pydantic_model(py, object)?));
200    }
201
202    // check if object has response_type method
203    let response_type = check_response_type(object)?;
204    if let Some(response_type) = response_type {
205        return Ok((
206            response_type.clone(),
207            Some(get_json_schema_from_response_type(&response_type)?),
208        ));
209    }
210
211    Ok((ResponseType::Null, None))
212}
213
214#[pyclass]
215#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
216#[serde(deny_unknown_fields)] // ensure strict validation
217pub struct Score {
218    #[pyo3(get)]
219    #[schemars(range(min = 1, max = 5))]
220    pub score: i64,
221
222    #[pyo3(get)]
223    pub reason: String,
224}
225#[pymethods]
226impl Score {
227    #[staticmethod]
228    pub fn response_type() -> ResponseType {
229        ResponseType::Score
230    }
231
232    #[staticmethod]
233    pub fn model_validate_json(json_string: String) -> Result<Score, TypeError> {
234        Ok(serde_json::from_str(&json_string)?)
235    }
236
237    #[staticmethod]
238    pub fn model_json_schema<'py>(py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
239        let schema = Score::get_structured_output_schema();
240        Ok(pythonize(py, &schema)?)
241    }
242
243    pub fn __str__(&self) -> String {
244        PyHelperFuncs::__str__(self)
245    }
246}
247
248impl StructuredOutput for Score {}
249
250#[pyclass]
251#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
252pub enum ResponseType {
253    Score,
254    Pydantic,
255    #[default]
256    Null, // This is used when no response type is specified
257}
258
259impl Display for ResponseType {
260    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261        match self {
262            ResponseType::Score => write!(f, "Score"),
263            ResponseType::Pydantic => write!(f, "Pydantic"),
264            ResponseType::Null => write!(f, "Null"),
265        }
266    }
267}
268
269// add conversion logic based on message conversion trait
270
271#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
272#[serde(untagged)]
273pub enum MessageNum {
274    OpenAIMessageV1(OpenAIChatMessage),
275    AnthropicMessageV1(AnthropicMessage),
276    GeminiContentV1(GeminiContent),
277
278    // this is a special case for Anthropic system messages
279    AnthropicSystemMessageV1(TextBlockParam),
280}
281
282impl MessageNum {
283    /// Checks if the message type matches the given provider
284    fn matches_provider(&self, provider: &Provider) -> bool {
285        matches!(
286            (self, provider),
287            (MessageNum::OpenAIMessageV1(_), Provider::OpenAI)
288                | (MessageNum::AnthropicMessageV1(_), Provider::Anthropic)
289                | (MessageNum::AnthropicSystemMessageV1(_), Provider::Anthropic)
290                | (MessageNum::GeminiContentV1(_), Provider::Google)
291                | (MessageNum::GeminiContentV1(_), Provider::Vertex)
292                | (MessageNum::GeminiContentV1(_), Provider::Gemini)
293        )
294    }
295
296    /// Converts the message to an openai message
297    /// This is only done for anthropic and gemini messages
298    /// openai message will return a failure if called on an openai message
299    /// Control flow should ensure this is only called on non-openai messages
300    fn to_openai_message(&self) -> Result<MessageNum, TypeError> {
301        match self {
302            MessageNum::AnthropicMessageV1(msg) => {
303                Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
304            }
305            MessageNum::GeminiContentV1(msg) => {
306                Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
307            }
308            _ => Err(TypeError::CantConvertSelf),
309        }
310    }
311
312    /// Converts to Anthropic message format
313    fn to_anthropic_message(&self) -> Result<MessageNum, TypeError> {
314        match self {
315            MessageNum::OpenAIMessageV1(msg) => {
316                Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
317            }
318            MessageNum::GeminiContentV1(msg) => {
319                Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
320            }
321            _ => Err(TypeError::CantConvertSelf),
322        }
323    }
324
325    /// Converts to Google Gemini message format
326    fn to_google_message(&self) -> Result<MessageNum, TypeError> {
327        match self {
328            MessageNum::OpenAIMessageV1(msg) => {
329                Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
330            }
331            MessageNum::AnthropicMessageV1(msg) => {
332                Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
333            }
334            _ => Err(TypeError::CantConvertSelf),
335        }
336    }
337
338    fn convert_message_to_provider_type(
339        &self,
340        provider: &Provider,
341    ) -> Result<MessageNum, TypeError> {
342        match provider {
343            Provider::OpenAI => self.to_openai_message(),
344            Provider::Anthropic => self.to_anthropic_message(),
345            Provider::Google => self.to_google_message(),
346            Provider::Vertex => self.to_google_message(),
347            Provider::Gemini => self.to_google_message(),
348            _ => Err(TypeError::UnsupportedProviderError),
349        }
350    }
351
352    pub fn convert_message(&mut self, provider: &Provider) -> Result<(), TypeError> {
353        // if message already matches provider, return Ok
354        if self.matches_provider(provider) {
355            return Ok(());
356        }
357        let converted = self.convert_message_to_provider_type(provider)?;
358        *self = converted;
359        Ok(())
360    }
361
362    pub fn anthropic_message_to_system_message(&mut self) -> Result<(), TypeError> {
363        match self {
364            MessageNum::AnthropicMessageV1(msg) => {
365                let text_param = msg.to_text_block_param()?;
366                *self = MessageNum::AnthropicSystemMessageV1(text_param);
367                Ok(())
368            }
369            _ => Err(TypeError::Error(
370                "Cannot convert non-AnthropicMessageV1 to system message".to_string(),
371            )),
372        }
373    }
374    pub fn role(&self) -> &str {
375        match self {
376            MessageNum::OpenAIMessageV1(msg) => &msg.role,
377            MessageNum::AnthropicMessageV1(msg) => &msg.role,
378            MessageNum::GeminiContentV1(msg) => &msg.role,
379            _ => "system",
380        }
381    }
382    pub fn bind(&self, name: &str, value: &str) -> Result<Self, TypeError> {
383        match self {
384            MessageNum::OpenAIMessageV1(msg) => {
385                let bound_msg = msg.bind(name, value)?;
386                Ok(MessageNum::OpenAIMessageV1(bound_msg))
387            }
388            MessageNum::AnthropicMessageV1(msg) => {
389                let bound_msg = msg.bind(name, value)?;
390                Ok(MessageNum::AnthropicMessageV1(bound_msg))
391            }
392            MessageNum::GeminiContentV1(msg) => {
393                let bound_msg = msg.bind(name, value)?;
394                Ok(MessageNum::GeminiContentV1(bound_msg))
395            }
396            _ => Ok(self.clone()),
397        }
398    }
399    pub fn bind_mut(&mut self, name: &str, value: &str) -> Result<(), TypeError> {
400        match self {
401            MessageNum::OpenAIMessageV1(msg) => msg.bind_mut(name, value),
402            MessageNum::AnthropicMessageV1(msg) => msg.bind_mut(name, value),
403            MessageNum::GeminiContentV1(msg) => msg.bind_mut(name, value),
404            _ => Ok(()),
405        }
406    }
407
408    pub(crate) fn extract_variables(&self) -> Vec<String> {
409        match self {
410            MessageNum::OpenAIMessageV1(msg) => msg.extract_variables(),
411            MessageNum::AnthropicMessageV1(msg) => msg.extract_variables(),
412            MessageNum::GeminiContentV1(msg) => msg.extract_variables(),
413            _ => vec![],
414        }
415    }
416
417    pub fn to_bound_py_object<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
418        match self {
419            MessageNum::OpenAIMessageV1(msg) => {
420                let bound_msg = msg.clone().into_bound_py_any(py)?;
421                Ok(bound_msg)
422            }
423            MessageNum::AnthropicMessageV1(msg) => {
424                let bound_msg = msg.clone().into_bound_py_any(py)?;
425                Ok(bound_msg)
426            }
427            MessageNum::GeminiContentV1(msg) => {
428                let bound_msg = msg.clone().into_bound_py_any(py)?;
429                Ok(bound_msg)
430            }
431            MessageNum::AnthropicSystemMessageV1(msg) => {
432                // Convert to AnthropicMessage first
433                let anthropic_msg = AnthropicMessage::from_text(msg.text.clone(), self.role())?;
434                let bound_msg = anthropic_msg.into_bound_py_any(py)?;
435                Ok(bound_msg)
436            }
437        }
438    }
439
440    pub fn to_bound_openai_message<'py>(
441        &self,
442        py: Python<'py>,
443    ) -> Result<Bound<'py, OpenAIChatMessage>, TypeError> {
444        match self {
445            MessageNum::OpenAIMessageV1(msg) => {
446                let py_obj = Py::new(py, msg.clone())?;
447                let bound = py_obj.bind(py);
448                Ok(bound.clone())
449            }
450            _ => Err(TypeError::CantConvertSelf),
451        }
452    }
453
454    pub fn to_bound_gemini_message<'py>(
455        &self,
456        py: Python<'py>,
457    ) -> Result<Bound<'py, GeminiContent>, TypeError> {
458        match self {
459            MessageNum::GeminiContentV1(msg) => {
460                let py_obj = Py::new(py, msg.clone())?;
461                let bound = py_obj.bind(py);
462                Ok(bound.clone())
463            }
464            _ => Err(TypeError::CantConvertSelf),
465        }
466    }
467
468    pub fn to_bound_anthropic_message<'py>(
469        &self,
470        py: Python<'py>,
471    ) -> Result<Bound<'py, AnthropicMessage>, TypeError> {
472        match self {
473            MessageNum::AnthropicMessageV1(msg) => {
474                let py_obj = Py::new(py, msg.clone())?;
475                let bound = py_obj.bind(py);
476                Ok(bound.clone())
477            }
478            _ => Err(TypeError::CantConvertSelf),
479        }
480    }
481
482    pub fn is_system_message(&self) -> bool {
483        match self {
484            MessageNum::OpenAIMessageV1(msg) => {
485                msg.role == Role::Developer.to_string() || msg.role == Role::System.to_string()
486            }
487            MessageNum::AnthropicMessageV1(msg) => msg.role == Role::System.to_string(),
488            MessageNum::GeminiContentV1(msg) => msg.role == Role::Model.to_string(),
489            MessageNum::AnthropicSystemMessageV1(_) => true,
490        }
491    }
492
493    pub fn is_user_message(&self) -> bool {
494        match self {
495            MessageNum::OpenAIMessageV1(msg) => msg.role == Role::User.to_string(),
496            MessageNum::AnthropicMessageV1(msg) => msg.role == Role::User.to_string(),
497            MessageNum::GeminiContentV1(msg) => msg.role == Role::User.to_string(),
498            MessageNum::AnthropicSystemMessageV1(_) => false,
499        }
500    }
501}
502
503#[derive(Debug, Clone)]
504pub enum ResponseContent {
505    OpenAI(Choice),
506    Google(Candidate),
507    Anthropic(ResponseContentBlock),
508    PredictResponse(PredictResponse),
509}
510
511#[pyclass]
512pub struct OpenAIMessageList {
513    pub messages: Vec<OpenAIChatMessage>,
514}
515
516#[pymethods]
517impl OpenAIMessageList {
518    fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<OpenAIMessageIterator>> {
519        let iter = OpenAIMessageIterator {
520            inner: slf.messages.clone().into_iter(),
521        };
522        Py::new(slf.py(), iter)
523    }
524
525    pub fn __len__(&self) -> usize {
526        self.messages.len()
527    }
528
529    pub fn __getitem__(&self, index: isize) -> Result<OpenAIChatMessage, TypeError> {
530        let len = self.messages.len() as isize;
531        let normalized_index = if index < 0 { len + index } else { index };
532
533        if normalized_index < 0 || normalized_index >= len {
534            return Err(TypeError::Error(format!(
535                "Index {} out of range for list of length {}",
536                index, len
537            )));
538        }
539
540        Ok(self.messages[normalized_index as usize].clone())
541    }
542
543    pub fn __str__(&self) -> String {
544        PyHelperFuncs::__str__(&self.messages)
545    }
546
547    pub fn __repr__(&self) -> String {
548        self.__str__()
549    }
550}
551
552#[pyclass]
553pub struct OpenAIMessageIterator {
554    inner: std::vec::IntoIter<OpenAIChatMessage>,
555}
556
557#[pymethods]
558impl OpenAIMessageIterator {
559    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
560        slf
561    }
562
563    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<OpenAIChatMessage> {
564        slf.inner.next()
565    }
566}
567
568#[pyclass]
569pub struct AnthropicMessageList {
570    pub messages: Vec<AnthropicMessage>,
571}
572
573#[pymethods]
574impl AnthropicMessageList {
575    fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<AnthropicMessageIterator>> {
576        let iter = AnthropicMessageIterator {
577            inner: slf.messages.clone().into_iter(),
578        };
579        Py::new(slf.py(), iter)
580    }
581
582    pub fn __len__(&self) -> usize {
583        self.messages.len()
584    }
585
586    pub fn __getitem__(&self, index: isize) -> Result<AnthropicMessage, TypeError> {
587        let len = self.messages.len() as isize;
588        let normalized_index = if index < 0 { len + index } else { index };
589
590        if normalized_index < 0 || normalized_index >= len {
591            return Err(TypeError::Error(format!(
592                "Index {} out of range for list of length {}",
593                index, len
594            )));
595        }
596
597        Ok(self.messages[normalized_index as usize].clone())
598    }
599
600    pub fn __str__(&self) -> String {
601        PyHelperFuncs::__str__(&self.messages)
602    }
603
604    pub fn __repr__(&self) -> String {
605        self.__str__()
606    }
607}
608
609#[pyclass]
610pub struct AnthropicMessageIterator {
611    inner: std::vec::IntoIter<AnthropicMessage>,
612}
613
614#[pymethods]
615impl AnthropicMessageIterator {
616    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
617        slf
618    }
619
620    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<AnthropicMessage> {
621        slf.inner.next()
622    }
623}
624
625#[pyclass]
626pub struct GeminiContentList {
627    pub messages: Vec<GeminiContent>,
628}
629
630#[pymethods]
631impl GeminiContentList {
632    fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<GeminiContentIterator>> {
633        let iter = GeminiContentIterator {
634            inner: slf.messages.clone().into_iter(),
635        };
636        Py::new(slf.py(), iter)
637    }
638
639    pub fn __len__(&self) -> usize {
640        self.messages.len()
641    }
642
643    pub fn __getitem__(&self, index: isize) -> Result<GeminiContent, TypeError> {
644        let len = self.messages.len() as isize;
645        let normalized_index = if index < 0 { len + index } else { index };
646
647        if normalized_index < 0 || normalized_index >= len {
648            return Err(TypeError::Error(format!(
649                "Index {} out of range for list of length {}",
650                index, len
651            )));
652        }
653
654        Ok(self.messages[normalized_index as usize].clone())
655    }
656
657    pub fn __str__(&self) -> String {
658        PyHelperFuncs::__str__(&self.messages)
659    }
660
661    pub fn __repr__(&self) -> String {
662        self.__str__()
663    }
664}
665
666#[pyclass]
667pub struct GeminiContentIterator {
668    inner: std::vec::IntoIter<GeminiContent>,
669}
670
671#[pymethods]
672impl GeminiContentIterator {
673    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
674        slf
675    }
676
677    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<GeminiContent> {
678        slf.inner.next()
679    }
680}