1use crate::anthropic::v1::request::MessageParam as AnthropicMessage;
2use crate::anthropic::v1::request::TextBlockParam;
3use crate::anthropic::v1::response::ResponseContentBlock;
4use crate::google::v1::generate::Candidate;
5use crate::google::v1::generate::GeminiContent;
6use crate::google::PredictResponse;
7use crate::openai::v1::chat::request::ChatMessage as OpenAIChatMessage;
8use crate::openai::v1::Choice;
9use crate::traits::MessageConversion;
10use crate::traits::PromptMessageExt;
11use crate::Provider;
12use crate::{StructuredOutput, TypeError};
13use potato_util::PyHelperFuncs;
14use pyo3::types::PyAnyMethods;
15use pyo3::{prelude::*, IntoPyObjectExt};
16use pythonize::{depythonize, pythonize};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::fmt::Display;
21use tracing::error;
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23#[pyclass]
24pub enum Role {
25 User,
26 Assistant,
27 Developer,
28 Tool,
29 Model,
30 System,
31}
32
33#[pymethods]
34impl Role {
35 #[pyo3(name = "as_str")]
36 pub fn as_str_py(&self) -> &'static str {
37 self.as_str()
38 }
39}
40
41impl Role {
42 pub const fn as_str(&self) -> &'static str {
44 match self {
45 Role::User => "user",
46 Role::Assistant => "assistant",
47 Role::Developer => "developer",
48 Role::Tool => "tool",
49 Role::Model => "model",
50 Role::System => "system",
51 }
52 }
53}
54
55impl Display for Role {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 Role::User => write!(f, "user"),
59 Role::Assistant => write!(f, "assistant"),
60 Role::Developer => write!(f, "developer"),
61 Role::Tool => write!(f, "tool"),
62 Role::Model => write!(f, "model"),
63 Role::System => write!(f, "system"),
64 }
65 }
66}
67
68impl From<Role> for &str {
69 fn from(role: Role) -> Self {
70 match role {
71 Role::User => "user",
72 Role::Assistant => "assistant",
73 Role::Developer => "developer",
74 Role::Tool => "tool",
75 Role::Model => "model",
76 Role::System => "system",
77 }
78 }
79}
80
81pub trait DeserializePromptValExt: for<'de> serde::Deserialize<'de> {
82 fn model_validate_json(value: &Value) -> Result<Self, serde_json::Error> {
90 serde_json::from_value(value.clone())
91 }
92}
93
94pub fn get_pydantic_module<'py>(py: Python<'py>, module_name: &str) -> PyResult<Bound<'py, PyAny>> {
95 py.import("pydantic_ai")?.getattr(module_name)
96}
97
98pub fn check_pydantic_model<'py>(
105 py: Python<'py>,
106 object: &Bound<'_, PyAny>,
107) -> Result<bool, TypeError> {
108 let pydantic = match py.import("pydantic").map_err(|e| {
110 error!("Failed to import pydantic: {}", e);
111 false
112 }) {
113 Ok(pydantic) => pydantic,
114 Err(_) => return Ok(false),
115 };
116
117 let is_subclass = py.import("builtins")?.getattr("issubclass")?;
119
120 let basemodel = pydantic.getattr("BaseModel")?;
122 let matched = is_subclass.call1((object, basemodel))?.extract::<bool>()?;
123
124 Ok(matched)
125}
126
127fn get_json_schema_from_basemodel(object: &Bound<'_, PyAny>) -> Result<Value, TypeError> {
133 let schema = object.getattr("model_json_schema")?.call1(())?;
135
136 let mut schema: Value = depythonize(&schema)?;
137
138 if let Some(additional_properties) = schema.get_mut("additionalProperties") {
140 *additional_properties = serde_json::json!(false);
141 } else {
142 schema
143 .as_object_mut()
144 .unwrap()
145 .insert("additionalProperties".to_string(), serde_json::json!(false));
146 }
147
148 Ok(schema)
149}
150
151fn parse_pydantic_model<'py>(
152 py: Python<'py>,
153 object: &Bound<'_, PyAny>,
154) -> Result<Option<Value>, TypeError> {
155 let is_subclass = check_pydantic_model(py, object)?;
156 if is_subclass {
157 Ok(Some(get_json_schema_from_basemodel(object)?))
158 } else {
159 Ok(None)
160 }
161}
162
163pub fn check_response_type(object: &Bound<'_, PyAny>) -> Result<Option<ResponseType>, TypeError> {
164 let response_type = match object.getattr("response_type") {
166 Ok(method) => {
167 if method.is_callable() {
168 let response_type: ResponseType = method.call0()?.extract()?;
169 Some(response_type)
170 } else {
171 None
172 }
173 }
174 Err(_) => None,
175 };
176
177 Ok(response_type)
178}
179
180fn get_json_schema_from_response_type(response_type: &ResponseType) -> Result<Value, TypeError> {
181 match response_type {
182 ResponseType::Score => Ok(Score::get_structured_output_schema()),
183 _ => {
184 Err(TypeError::Error(format!(
186 "Unsupported response type: {response_type}"
187 )))
188 }
189 }
190}
191
192pub fn parse_response_to_json<'py>(
193 py: Python<'py>,
194 object: &Bound<'_, PyAny>,
195) -> Result<(ResponseType, Option<Value>), TypeError> {
196 let is_pydantic_model = check_pydantic_model(py, object)?;
198 if is_pydantic_model {
199 return Ok((ResponseType::Pydantic, parse_pydantic_model(py, object)?));
200 }
201
202 let response_type = check_response_type(object)?;
204 if let Some(response_type) = response_type {
205 return Ok((
206 response_type.clone(),
207 Some(get_json_schema_from_response_type(&response_type)?),
208 ));
209 }
210
211 Ok((ResponseType::Null, None))
212}
213
214#[pyclass]
215#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
216#[serde(deny_unknown_fields)] pub struct Score {
218 #[pyo3(get)]
219 #[schemars(range(min = 1, max = 5))]
220 pub score: i64,
221
222 #[pyo3(get)]
223 pub reason: String,
224}
225#[pymethods]
226impl Score {
227 #[staticmethod]
228 pub fn response_type() -> ResponseType {
229 ResponseType::Score
230 }
231
232 #[staticmethod]
233 pub fn model_validate_json(json_string: String) -> Result<Score, TypeError> {
234 Ok(serde_json::from_str(&json_string)?)
235 }
236
237 #[staticmethod]
238 pub fn model_json_schema<'py>(py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
239 let schema = Score::get_structured_output_schema();
240 Ok(pythonize(py, &schema)?)
241 }
242
243 pub fn __str__(&self) -> String {
244 PyHelperFuncs::__str__(self)
245 }
246}
247
248impl StructuredOutput for Score {}
249
250#[pyclass]
251#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
252pub enum ResponseType {
253 Score,
254 Pydantic,
255 #[default]
256 Null, }
258
259impl Display for ResponseType {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 match self {
262 ResponseType::Score => write!(f, "Score"),
263 ResponseType::Pydantic => write!(f, "Pydantic"),
264 ResponseType::Null => write!(f, "Null"),
265 }
266 }
267}
268
269#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
272#[serde(untagged)]
273pub enum MessageNum {
274 OpenAIMessageV1(OpenAIChatMessage),
275 AnthropicMessageV1(AnthropicMessage),
276 GeminiContentV1(GeminiContent),
277
278 AnthropicSystemMessageV1(TextBlockParam),
280
281 RawV1(serde_json::Value),
285}
286
287impl MessageNum {
288 fn matches_provider(&self, provider: &Provider) -> bool {
290 matches!(
291 (self, provider),
292 (MessageNum::OpenAIMessageV1(_), Provider::OpenAI)
293 | (MessageNum::AnthropicMessageV1(_), Provider::Anthropic)
294 | (MessageNum::AnthropicSystemMessageV1(_), Provider::Anthropic)
295 | (MessageNum::GeminiContentV1(_), Provider::Google)
296 | (MessageNum::GeminiContentV1(_), Provider::Vertex)
297 | (MessageNum::GeminiContentV1(_), Provider::Gemini)
298 )
299 }
300
301 fn to_openai_message(&self) -> Result<MessageNum, TypeError> {
306 match self {
307 MessageNum::AnthropicMessageV1(msg) => {
308 Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
309 }
310 MessageNum::GeminiContentV1(msg) => {
311 Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
312 }
313 _ => Err(TypeError::CantConvertSelf),
314 }
315 }
316
317 fn to_anthropic_message(&self) -> Result<MessageNum, TypeError> {
319 match self {
320 MessageNum::OpenAIMessageV1(msg) => {
321 Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
322 }
323 MessageNum::GeminiContentV1(msg) => {
324 Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
325 }
326 _ => Err(TypeError::CantConvertSelf),
327 }
328 }
329
330 fn to_google_message(&self) -> Result<MessageNum, TypeError> {
332 match self {
333 MessageNum::OpenAIMessageV1(msg) => {
334 Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
335 }
336 MessageNum::AnthropicMessageV1(msg) => {
337 Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
338 }
339 _ => Err(TypeError::CantConvertSelf),
340 }
341 }
342
343 fn convert_message_to_provider_type(
344 &self,
345 provider: &Provider,
346 ) -> Result<MessageNum, TypeError> {
347 match provider {
348 Provider::OpenAI => self.to_openai_message(),
349 Provider::Anthropic => self.to_anthropic_message(),
350 Provider::Google => self.to_google_message(),
351 Provider::Vertex => self.to_google_message(),
352 Provider::Gemini => self.to_google_message(),
353 _ => Err(TypeError::UnsupportedProviderError),
354 }
355 }
356
357 pub fn convert_message(&mut self, provider: &Provider) -> Result<(), TypeError> {
358 if matches!(self, MessageNum::RawV1(_)) {
360 return Ok(());
361 }
362 if self.matches_provider(provider) {
364 return Ok(());
365 }
366 let converted = self.convert_message_to_provider_type(provider)?;
367 *self = converted;
368 Ok(())
369 }
370
371 pub fn anthropic_message_to_system_message(&mut self) -> Result<(), TypeError> {
372 match self {
373 MessageNum::AnthropicMessageV1(msg) => {
374 let text_param = msg.to_text_block_param()?;
375 *self = MessageNum::AnthropicSystemMessageV1(text_param);
376 Ok(())
377 }
378 _ => Err(TypeError::Error(
379 "Cannot convert non-AnthropicMessageV1 to system message".to_string(),
380 )),
381 }
382 }
383 pub fn role(&self) -> &str {
384 match self {
385 MessageNum::OpenAIMessageV1(msg) => &msg.role,
386 MessageNum::AnthropicMessageV1(msg) => &msg.role,
387 MessageNum::GeminiContentV1(msg) => &msg.role,
388 _ => "system",
389 }
390 }
391 pub fn bind(&self, name: &str, value: &str) -> Result<Self, TypeError> {
392 match self {
393 MessageNum::OpenAIMessageV1(msg) => {
394 let bound_msg = msg.bind(name, value)?;
395 Ok(MessageNum::OpenAIMessageV1(bound_msg))
396 }
397 MessageNum::AnthropicMessageV1(msg) => {
398 let bound_msg = msg.bind(name, value)?;
399 Ok(MessageNum::AnthropicMessageV1(bound_msg))
400 }
401 MessageNum::GeminiContentV1(msg) => {
402 let bound_msg = msg.bind(name, value)?;
403 Ok(MessageNum::GeminiContentV1(bound_msg))
404 }
405 _ => Ok(self.clone()),
406 }
407 }
408 pub fn bind_mut(&mut self, name: &str, value: &str) -> Result<(), TypeError> {
409 match self {
410 MessageNum::OpenAIMessageV1(msg) => msg.bind_mut(name, value),
411 MessageNum::AnthropicMessageV1(msg) => msg.bind_mut(name, value),
412 MessageNum::GeminiContentV1(msg) => msg.bind_mut(name, value),
413 _ => Ok(()),
414 }
415 }
416
417 pub(crate) fn extract_variables(&self) -> Vec<String> {
418 match self {
419 MessageNum::OpenAIMessageV1(msg) => msg.extract_variables(),
420 MessageNum::AnthropicMessageV1(msg) => msg.extract_variables(),
421 MessageNum::GeminiContentV1(msg) => msg.extract_variables(),
422 _ => vec![],
423 }
424 }
425
426 pub fn to_bound_py_object<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
427 match self {
428 MessageNum::OpenAIMessageV1(msg) => {
429 let bound_msg = msg.clone().into_bound_py_any(py)?;
430 Ok(bound_msg)
431 }
432 MessageNum::AnthropicMessageV1(msg) => {
433 let bound_msg = msg.clone().into_bound_py_any(py)?;
434 Ok(bound_msg)
435 }
436 MessageNum::GeminiContentV1(msg) => {
437 let bound_msg = msg.clone().into_bound_py_any(py)?;
438 Ok(bound_msg)
439 }
440 MessageNum::AnthropicSystemMessageV1(msg) => {
441 let anthropic_msg = AnthropicMessage::from_text(msg.text.clone(), self.role())?;
443 let bound_msg = anthropic_msg.into_bound_py_any(py)?;
444 Ok(bound_msg)
445 }
446 MessageNum::RawV1(v) => Ok(pythonize(py, v)?),
447 }
448 }
449
450 pub fn to_bound_openai_message<'py>(
451 &self,
452 py: Python<'py>,
453 ) -> Result<Bound<'py, OpenAIChatMessage>, TypeError> {
454 match self {
455 MessageNum::OpenAIMessageV1(msg) => {
456 let py_obj = Py::new(py, msg.clone())?;
457 let bound = py_obj.bind(py);
458 Ok(bound.clone())
459 }
460 _ => Err(TypeError::CantConvertSelf),
461 }
462 }
463
464 pub fn to_bound_gemini_message<'py>(
465 &self,
466 py: Python<'py>,
467 ) -> Result<Bound<'py, GeminiContent>, TypeError> {
468 match self {
469 MessageNum::GeminiContentV1(msg) => {
470 let py_obj = Py::new(py, msg.clone())?;
471 let bound = py_obj.bind(py);
472 Ok(bound.clone())
473 }
474 _ => Err(TypeError::CantConvertSelf),
475 }
476 }
477
478 pub fn to_bound_anthropic_message<'py>(
479 &self,
480 py: Python<'py>,
481 ) -> Result<Bound<'py, AnthropicMessage>, TypeError> {
482 match self {
483 MessageNum::AnthropicMessageV1(msg) => {
484 let py_obj = Py::new(py, msg.clone())?;
485 let bound = py_obj.bind(py);
486 Ok(bound.clone())
487 }
488 _ => Err(TypeError::CantConvertSelf),
489 }
490 }
491
492 pub fn is_system_message(&self) -> bool {
493 match self {
494 MessageNum::OpenAIMessageV1(msg) => {
495 msg.role == Role::Developer.to_string() || msg.role == Role::System.to_string()
496 }
497 MessageNum::AnthropicMessageV1(msg) => msg.role == Role::System.to_string(),
498 MessageNum::GeminiContentV1(msg) => msg.role == Role::Model.to_string(),
499 MessageNum::AnthropicSystemMessageV1(_) => true,
500 MessageNum::RawV1(v) => v
501 .get("role")
502 .and_then(|r| r.as_str())
503 .map(|r| r == "system" || r == "developer")
504 .unwrap_or(false),
505 }
506 }
507
508 pub fn is_user_message(&self) -> bool {
509 match self {
510 MessageNum::OpenAIMessageV1(msg) => msg.role == Role::User.to_string(),
511 MessageNum::AnthropicMessageV1(msg) => msg.role == Role::User.to_string(),
512 MessageNum::GeminiContentV1(msg) => msg.role == Role::User.to_string(),
513 MessageNum::AnthropicSystemMessageV1(_) => false,
514 MessageNum::RawV1(v) => v
515 .get("role")
516 .and_then(|r| r.as_str())
517 .map(|r| r == "user")
518 .unwrap_or(false),
519 }
520 }
521}
522
523#[derive(Debug, Clone)]
524pub enum ResponseContent {
525 OpenAI(Choice),
526 Google(Candidate),
527 Anthropic(ResponseContentBlock),
528 PredictResponse(PredictResponse),
529}
530
531#[pyclass]
532pub struct OpenAIMessageList {
533 pub messages: Vec<OpenAIChatMessage>,
534}
535
536#[pymethods]
537impl OpenAIMessageList {
538 fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<OpenAIMessageIterator>> {
539 let iter = OpenAIMessageIterator {
540 inner: slf.messages.clone().into_iter(),
541 };
542 Py::new(slf.py(), iter)
543 }
544
545 pub fn __len__(&self) -> usize {
546 self.messages.len()
547 }
548
549 pub fn __getitem__(&self, index: isize) -> Result<OpenAIChatMessage, TypeError> {
550 let len = self.messages.len() as isize;
551 let normalized_index = if index < 0 { len + index } else { index };
552
553 if normalized_index < 0 || normalized_index >= len {
554 return Err(TypeError::Error(format!(
555 "Index {} out of range for list of length {}",
556 index, len
557 )));
558 }
559
560 Ok(self.messages[normalized_index as usize].clone())
561 }
562
563 pub fn __str__(&self) -> String {
564 PyHelperFuncs::__str__(&self.messages)
565 }
566
567 pub fn __repr__(&self) -> String {
568 self.__str__()
569 }
570}
571
572#[pyclass]
573pub struct OpenAIMessageIterator {
574 inner: std::vec::IntoIter<OpenAIChatMessage>,
575}
576
577#[pymethods]
578impl OpenAIMessageIterator {
579 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
580 slf
581 }
582
583 fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<OpenAIChatMessage> {
584 slf.inner.next()
585 }
586}
587
588#[pyclass]
589pub struct AnthropicMessageList {
590 pub messages: Vec<AnthropicMessage>,
591}
592
593#[pymethods]
594impl AnthropicMessageList {
595 fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<AnthropicMessageIterator>> {
596 let iter = AnthropicMessageIterator {
597 inner: slf.messages.clone().into_iter(),
598 };
599 Py::new(slf.py(), iter)
600 }
601
602 pub fn __len__(&self) -> usize {
603 self.messages.len()
604 }
605
606 pub fn __getitem__(&self, index: isize) -> Result<AnthropicMessage, TypeError> {
607 let len = self.messages.len() as isize;
608 let normalized_index = if index < 0 { len + index } else { index };
609
610 if normalized_index < 0 || normalized_index >= len {
611 return Err(TypeError::Error(format!(
612 "Index {} out of range for list of length {}",
613 index, len
614 )));
615 }
616
617 Ok(self.messages[normalized_index as usize].clone())
618 }
619
620 pub fn __str__(&self) -> String {
621 PyHelperFuncs::__str__(&self.messages)
622 }
623
624 pub fn __repr__(&self) -> String {
625 self.__str__()
626 }
627}
628
629#[pyclass]
630pub struct AnthropicMessageIterator {
631 inner: std::vec::IntoIter<AnthropicMessage>,
632}
633
634#[pymethods]
635impl AnthropicMessageIterator {
636 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
637 slf
638 }
639
640 fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<AnthropicMessage> {
641 slf.inner.next()
642 }
643}
644
645#[pyclass]
646pub struct GeminiContentList {
647 pub messages: Vec<GeminiContent>,
648}
649
650#[pymethods]
651impl GeminiContentList {
652 fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<GeminiContentIterator>> {
653 let iter = GeminiContentIterator {
654 inner: slf.messages.clone().into_iter(),
655 };
656 Py::new(slf.py(), iter)
657 }
658
659 pub fn __len__(&self) -> usize {
660 self.messages.len()
661 }
662
663 pub fn __getitem__(&self, index: isize) -> Result<GeminiContent, TypeError> {
664 let len = self.messages.len() as isize;
665 let normalized_index = if index < 0 { len + index } else { index };
666
667 if normalized_index < 0 || normalized_index >= len {
668 return Err(TypeError::Error(format!(
669 "Index {} out of range for list of length {}",
670 index, len
671 )));
672 }
673
674 Ok(self.messages[normalized_index as usize].clone())
675 }
676
677 pub fn __str__(&self) -> String {
678 PyHelperFuncs::__str__(&self.messages)
679 }
680
681 pub fn __repr__(&self) -> String {
682 self.__str__()
683 }
684}
685
686#[pyclass]
687pub struct GeminiContentIterator {
688 inner: std::vec::IntoIter<GeminiContent>,
689}
690
691#[pymethods]
692impl GeminiContentIterator {
693 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
694 slf
695 }
696
697 fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<GeminiContent> {
698 slf.inner.next()
699 }
700}