1use crate::anthropic::v1::request::MessageParam as AnthropicMessage;
2use crate::anthropic::v1::request::TextBlockParam;
3use crate::anthropic::v1::response::ResponseContentBlock;
4use crate::google::v1::generate::Candidate;
5use crate::google::v1::generate::GeminiContent;
6use crate::google::PredictResponse;
7use crate::openai::v1::chat::request::ChatMessage as OpenAIChatMessage;
8use crate::openai::v1::Choice;
9use crate::traits::MessageConversion;
10use crate::traits::PromptMessageExt;
11use crate::Provider;
12use crate::{StructuredOutput, TypeError};
13use potato_util::PyHelperFuncs;
14use pyo3::types::PyAnyMethods;
15use pyo3::{prelude::*, IntoPyObjectExt};
16use pythonize::{depythonize, pythonize};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::fmt::Display;
21use tracing::error;
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23#[pyclass]
24pub enum Role {
25 User,
26 Assistant,
27 Developer,
28 Tool,
29 Model,
30 System,
31}
32
33#[pymethods]
34impl Role {
35 #[pyo3(name = "as_str")]
36 pub fn as_str_py(&self) -> &'static str {
37 self.as_str()
38 }
39}
40
41impl Role {
42 pub const fn as_str(&self) -> &'static str {
44 match self {
45 Role::User => "user",
46 Role::Assistant => "assistant",
47 Role::Developer => "developer",
48 Role::Tool => "tool",
49 Role::Model => "model",
50 Role::System => "system",
51 }
52 }
53}
54
55impl Display for Role {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 Role::User => write!(f, "user"),
59 Role::Assistant => write!(f, "assistant"),
60 Role::Developer => write!(f, "developer"),
61 Role::Tool => write!(f, "tool"),
62 Role::Model => write!(f, "model"),
63 Role::System => write!(f, "system"),
64 }
65 }
66}
67
68impl From<Role> for &str {
69 fn from(role: Role) -> Self {
70 match role {
71 Role::User => "user",
72 Role::Assistant => "assistant",
73 Role::Developer => "developer",
74 Role::Tool => "tool",
75 Role::Model => "model",
76 Role::System => "system",
77 }
78 }
79}
80
81pub trait DeserializePromptValExt: for<'de> serde::Deserialize<'de> {
82 fn model_validate_json(value: &Value) -> Result<Self, serde_json::Error> {
90 serde_json::from_value(value.clone())
91 }
92}
93
94pub fn get_pydantic_module<'py>(py: Python<'py>, module_name: &str) -> PyResult<Bound<'py, PyAny>> {
95 py.import("pydantic_ai")?.getattr(module_name)
96}
97
98pub fn check_pydantic_model<'py>(
105 py: Python<'py>,
106 object: &Bound<'_, PyAny>,
107) -> Result<bool, TypeError> {
108 let pydantic = match py.import("pydantic").map_err(|e| {
110 error!("Failed to import pydantic: {}", e);
111 false
112 }) {
113 Ok(pydantic) => pydantic,
114 Err(_) => return Ok(false),
115 };
116
117 let is_subclass = py.import("builtins")?.getattr("issubclass")?;
119
120 let basemodel = pydantic.getattr("BaseModel")?;
122 let matched = is_subclass.call1((object, basemodel))?.extract::<bool>()?;
123
124 Ok(matched)
125}
126
127fn get_json_schema_from_basemodel(object: &Bound<'_, PyAny>) -> Result<Value, TypeError> {
133 let schema = object.getattr("model_json_schema")?.call1(())?;
135
136 let mut schema: Value = depythonize(&schema)?;
137
138 if let Some(additional_properties) = schema.get_mut("additionalProperties") {
140 *additional_properties = serde_json::json!(false);
141 } else {
142 schema
143 .as_object_mut()
144 .unwrap()
145 .insert("additionalProperties".to_string(), serde_json::json!(false));
146 }
147
148 Ok(schema)
149}
150
151fn parse_pydantic_model<'py>(
152 py: Python<'py>,
153 object: &Bound<'_, PyAny>,
154) -> Result<Option<Value>, TypeError> {
155 let is_subclass = check_pydantic_model(py, object)?;
156 if is_subclass {
157 Ok(Some(get_json_schema_from_basemodel(object)?))
158 } else {
159 Ok(None)
160 }
161}
162
163pub fn check_response_type(object: &Bound<'_, PyAny>) -> Result<Option<ResponseType>, TypeError> {
164 let response_type = match object.getattr("response_type") {
166 Ok(method) => {
167 if method.is_callable() {
168 let response_type: ResponseType = method.call0()?.extract()?;
169 Some(response_type)
170 } else {
171 None
172 }
173 }
174 Err(_) => None,
175 };
176
177 Ok(response_type)
178}
179
180fn get_json_schema_from_response_type(response_type: &ResponseType) -> Result<Value, TypeError> {
181 match response_type {
182 ResponseType::Score => Ok(Score::get_structured_output_schema()),
183 _ => {
184 Err(TypeError::Error(format!(
186 "Unsupported response type: {response_type}"
187 )))
188 }
189 }
190}
191
192pub fn parse_response_to_json<'py>(
193 py: Python<'py>,
194 object: &Bound<'_, PyAny>,
195) -> Result<(ResponseType, Option<Value>), TypeError> {
196 let is_pydantic_model = check_pydantic_model(py, object)?;
198 if is_pydantic_model {
199 return Ok((ResponseType::Pydantic, parse_pydantic_model(py, object)?));
200 }
201
202 let response_type = check_response_type(object)?;
204 if let Some(response_type) = response_type {
205 return Ok((
206 response_type.clone(),
207 Some(get_json_schema_from_response_type(&response_type)?),
208 ));
209 }
210
211 Ok((ResponseType::Null, None))
212}
213
214#[pyclass]
215#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
216#[serde(deny_unknown_fields)] pub struct Score {
218 #[pyo3(get)]
219 #[schemars(range(min = 1, max = 5))]
220 pub score: i64,
221
222 #[pyo3(get)]
223 pub reason: String,
224}
225#[pymethods]
226impl Score {
227 #[staticmethod]
228 pub fn response_type() -> ResponseType {
229 ResponseType::Score
230 }
231
232 #[staticmethod]
233 pub fn model_validate_json(json_string: String) -> Result<Score, TypeError> {
234 Ok(serde_json::from_str(&json_string)?)
235 }
236
237 #[staticmethod]
238 pub fn model_json_schema<'py>(py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
239 let schema = Score::get_structured_output_schema();
240 Ok(pythonize(py, &schema)?)
241 }
242
243 pub fn __str__(&self) -> String {
244 PyHelperFuncs::__str__(self)
245 }
246}
247
248impl StructuredOutput for Score {}
249
250#[pyclass]
251#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
252pub enum ResponseType {
253 Score,
254 Pydantic,
255 Null, }
257
258impl Display for ResponseType {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 match self {
261 ResponseType::Score => write!(f, "Score"),
262 ResponseType::Pydantic => write!(f, "Pydantic"),
263 ResponseType::Null => write!(f, "Null"),
264 }
265 }
266}
267
/// A chat message held in one provider's native message format.
///
/// Deserialized with `#[serde(untagged)]`: serde tries the variants in
/// declaration order and keeps the first that matches, so the variant order
/// is significant and must not be rearranged casually.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageNum {
    OpenAIMessageV1(OpenAIChatMessage),
    AnthropicMessageV1(AnthropicMessage),
    GeminiContentV1(GeminiContent),

    // Bare Anthropic text block used as a system prompt; produced from an
    // `AnthropicMessageV1` by `anthropic_message_to_system_message`.
    AnthropicSystemMessageV1(TextBlockParam),
}
280
281impl MessageNum {
282 fn matches_provider(&self, provider: &Provider) -> bool {
284 matches!(
285 (self, provider),
286 (MessageNum::OpenAIMessageV1(_), Provider::OpenAI)
287 | (MessageNum::AnthropicMessageV1(_), Provider::Anthropic)
288 | (MessageNum::AnthropicSystemMessageV1(_), Provider::Anthropic)
289 | (MessageNum::GeminiContentV1(_), Provider::Google)
290 | (MessageNum::GeminiContentV1(_), Provider::Vertex)
291 | (MessageNum::GeminiContentV1(_), Provider::Gemini)
292 )
293 }
294
295 fn to_openai_message(&self) -> Result<MessageNum, TypeError> {
300 match self {
301 MessageNum::AnthropicMessageV1(msg) => {
302 Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
303 }
304 MessageNum::GeminiContentV1(msg) => {
305 Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
306 }
307 _ => Err(TypeError::CantConvertSelf),
308 }
309 }
310
311 fn to_anthropic_message(&self) -> Result<MessageNum, TypeError> {
313 match self {
314 MessageNum::OpenAIMessageV1(msg) => {
315 Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
316 }
317 MessageNum::GeminiContentV1(msg) => {
318 Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
319 }
320 _ => Err(TypeError::CantConvertSelf),
321 }
322 }
323
324 fn to_google_message(&self) -> Result<MessageNum, TypeError> {
326 match self {
327 MessageNum::OpenAIMessageV1(msg) => {
328 Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
329 }
330 MessageNum::AnthropicMessageV1(msg) => {
331 Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
332 }
333 _ => Err(TypeError::CantConvertSelf),
334 }
335 }
336
337 fn convert_message_to_provider_type(
338 &self,
339 provider: &Provider,
340 ) -> Result<MessageNum, TypeError> {
341 match provider {
342 Provider::OpenAI => self.to_openai_message(),
343 Provider::Anthropic => self.to_anthropic_message(),
344 Provider::Google => self.to_google_message(),
345 Provider::Vertex => self.to_google_message(),
346 Provider::Gemini => self.to_google_message(),
347 _ => Err(TypeError::UnsupportedProviderError),
348 }
349 }
350
351 pub fn convert_message(&mut self, provider: &Provider) -> Result<(), TypeError> {
352 if self.matches_provider(provider) {
354 return Ok(());
355 }
356 let converted = self.convert_message_to_provider_type(provider)?;
357 *self = converted;
358 Ok(())
359 }
360
361 pub fn anthropic_message_to_system_message(&mut self) -> Result<(), TypeError> {
362 match self {
363 MessageNum::AnthropicMessageV1(msg) => {
364 let text_param = msg.to_text_block_param()?;
365 *self = MessageNum::AnthropicSystemMessageV1(text_param);
366 Ok(())
367 }
368 _ => Err(TypeError::Error(
369 "Cannot convert non-AnthropicMessageV1 to system message".to_string(),
370 )),
371 }
372 }
373 pub fn role(&self) -> &str {
374 match self {
375 MessageNum::OpenAIMessageV1(msg) => &msg.role,
376 MessageNum::AnthropicMessageV1(msg) => &msg.role,
377 MessageNum::GeminiContentV1(msg) => &msg.role,
378 _ => "system",
379 }
380 }
381 pub fn bind(&self, name: &str, value: &str) -> Result<Self, TypeError> {
382 match self {
383 MessageNum::OpenAIMessageV1(msg) => {
384 let bound_msg = msg.bind(name, value)?;
385 Ok(MessageNum::OpenAIMessageV1(bound_msg))
386 }
387 MessageNum::AnthropicMessageV1(msg) => {
388 let bound_msg = msg.bind(name, value)?;
389 Ok(MessageNum::AnthropicMessageV1(bound_msg))
390 }
391 MessageNum::GeminiContentV1(msg) => {
392 let bound_msg = msg.bind(name, value)?;
393 Ok(MessageNum::GeminiContentV1(bound_msg))
394 }
395 _ => Ok(self.clone()),
396 }
397 }
398 pub fn bind_mut(&mut self, name: &str, value: &str) -> Result<(), TypeError> {
399 match self {
400 MessageNum::OpenAIMessageV1(msg) => msg.bind_mut(name, value),
401 MessageNum::AnthropicMessageV1(msg) => msg.bind_mut(name, value),
402 MessageNum::GeminiContentV1(msg) => msg.bind_mut(name, value),
403 _ => Ok(()),
404 }
405 }
406
407 pub(crate) fn extract_variables(&self) -> Vec<String> {
408 match self {
409 MessageNum::OpenAIMessageV1(msg) => msg.extract_variables(),
410 MessageNum::AnthropicMessageV1(msg) => msg.extract_variables(),
411 MessageNum::GeminiContentV1(msg) => msg.extract_variables(),
412 _ => vec![],
413 }
414 }
415
416 pub fn to_bound_py_object<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
417 match self {
418 MessageNum::OpenAIMessageV1(msg) => {
419 let bound_msg = msg.clone().into_bound_py_any(py)?;
420 Ok(bound_msg)
421 }
422 MessageNum::AnthropicMessageV1(msg) => {
423 let bound_msg = msg.clone().into_bound_py_any(py)?;
424 Ok(bound_msg)
425 }
426 MessageNum::GeminiContentV1(msg) => {
427 let bound_msg = msg.clone().into_bound_py_any(py)?;
428 Ok(bound_msg)
429 }
430 MessageNum::AnthropicSystemMessageV1(msg) => {
431 let anthropic_msg = AnthropicMessage::from_text(msg.text.clone(), self.role())?;
433 let bound_msg = anthropic_msg.into_bound_py_any(py)?;
434 Ok(bound_msg)
435 }
436 }
437 }
438
439 pub fn to_bound_openai_message<'py>(
440 &self,
441 py: Python<'py>,
442 ) -> Result<Bound<'py, OpenAIChatMessage>, TypeError> {
443 match self {
444 MessageNum::OpenAIMessageV1(msg) => {
445 let py_obj = Py::new(py, msg.clone())?;
446 let bound = py_obj.bind(py);
447 Ok(bound.clone())
448 }
449 _ => Err(TypeError::CantConvertSelf),
450 }
451 }
452
453 pub fn to_bound_gemini_message<'py>(
454 &self,
455 py: Python<'py>,
456 ) -> Result<Bound<'py, GeminiContent>, TypeError> {
457 match self {
458 MessageNum::GeminiContentV1(msg) => {
459 let py_obj = Py::new(py, msg.clone())?;
460 let bound = py_obj.bind(py);
461 Ok(bound.clone())
462 }
463 _ => Err(TypeError::CantConvertSelf),
464 }
465 }
466
467 pub fn to_bound_anthropic_message<'py>(
468 &self,
469 py: Python<'py>,
470 ) -> Result<Bound<'py, AnthropicMessage>, TypeError> {
471 match self {
472 MessageNum::AnthropicMessageV1(msg) => {
473 let py_obj = Py::new(py, msg.clone())?;
474 let bound = py_obj.bind(py);
475 Ok(bound.clone())
476 }
477 _ => Err(TypeError::CantConvertSelf),
478 }
479 }
480
481 pub fn is_system_message(&self) -> bool {
482 match self {
483 MessageNum::OpenAIMessageV1(msg) => {
484 msg.role == Role::Developer.to_string() || msg.role == Role::System.to_string()
485 }
486 MessageNum::AnthropicMessageV1(msg) => msg.role == Role::System.to_string(),
487 MessageNum::GeminiContentV1(msg) => msg.role == Role::Model.to_string(),
488 MessageNum::AnthropicSystemMessageV1(_) => true,
489 }
490 }
491
492 pub fn is_user_message(&self) -> bool {
493 match self {
494 MessageNum::OpenAIMessageV1(msg) => msg.role == Role::User.to_string(),
495 MessageNum::AnthropicMessageV1(msg) => msg.role == Role::User.to_string(),
496 MessageNum::GeminiContentV1(msg) => msg.role == Role::User.to_string(),
497 MessageNum::AnthropicSystemMessageV1(_) => false,
498 }
499 }
500}
501
/// One unit of a model response, wrapping the provider-specific response type.
#[derive(Debug, Clone)]
pub enum ResponseContent {
    // OpenAI chat-completion choice.
    OpenAI(Choice),
    // Google/Gemini generation candidate.
    Google(Candidate),
    // Anthropic response content block.
    Anthropic(ResponseContentBlock),
    // Google predict-endpoint response.
    PredictResponse(PredictResponse),
}
509
510#[pyclass]
511pub struct OpenAIMessageList {
512 pub messages: Vec<OpenAIChatMessage>,
513}
514
515#[pymethods]
516impl OpenAIMessageList {
517 fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<OpenAIMessageIterator>> {
518 let iter = OpenAIMessageIterator {
519 inner: slf.messages.clone().into_iter(),
520 };
521 Py::new(slf.py(), iter)
522 }
523
524 pub fn __len__(&self) -> usize {
525 self.messages.len()
526 }
527
528 pub fn __getitem__(&self, index: isize) -> Result<OpenAIChatMessage, TypeError> {
529 let len = self.messages.len() as isize;
530 let normalized_index = if index < 0 { len + index } else { index };
531
532 if normalized_index < 0 || normalized_index >= len {
533 return Err(TypeError::Error(format!(
534 "Index {} out of range for list of length {}",
535 index, len
536 )));
537 }
538
539 Ok(self.messages[normalized_index as usize].clone())
540 }
541
542 pub fn __str__(&self) -> String {
543 PyHelperFuncs::__str__(&self.messages)
544 }
545
546 pub fn __repr__(&self) -> String {
547 self.__str__()
548 }
549}
550
551#[pyclass]
552pub struct OpenAIMessageIterator {
553 inner: std::vec::IntoIter<OpenAIChatMessage>,
554}
555
556#[pymethods]
557impl OpenAIMessageIterator {
558 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
559 slf
560 }
561
562 fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<OpenAIChatMessage> {
563 slf.inner.next()
564 }
565}
566
567#[pyclass]
568pub struct AnthropicMessageList {
569 pub messages: Vec<AnthropicMessage>,
570}
571
572#[pymethods]
573impl AnthropicMessageList {
574 fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<AnthropicMessageIterator>> {
575 let iter = AnthropicMessageIterator {
576 inner: slf.messages.clone().into_iter(),
577 };
578 Py::new(slf.py(), iter)
579 }
580
581 pub fn __len__(&self) -> usize {
582 self.messages.len()
583 }
584
585 pub fn __getitem__(&self, index: isize) -> Result<AnthropicMessage, TypeError> {
586 let len = self.messages.len() as isize;
587 let normalized_index = if index < 0 { len + index } else { index };
588
589 if normalized_index < 0 || normalized_index >= len {
590 return Err(TypeError::Error(format!(
591 "Index {} out of range for list of length {}",
592 index, len
593 )));
594 }
595
596 Ok(self.messages[normalized_index as usize].clone())
597 }
598
599 pub fn __str__(&self) -> String {
600 PyHelperFuncs::__str__(&self.messages)
601 }
602
603 pub fn __repr__(&self) -> String {
604 self.__str__()
605 }
606}
607
608#[pyclass]
609pub struct AnthropicMessageIterator {
610 inner: std::vec::IntoIter<AnthropicMessage>,
611}
612
613#[pymethods]
614impl AnthropicMessageIterator {
615 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
616 slf
617 }
618
619 fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<AnthropicMessage> {
620 slf.inner.next()
621 }
622}
623
624#[pyclass]
625pub struct GeminiContentList {
626 pub messages: Vec<GeminiContent>,
627}
628
629#[pymethods]
630impl GeminiContentList {
631 fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<GeminiContentIterator>> {
632 let iter = GeminiContentIterator {
633 inner: slf.messages.clone().into_iter(),
634 };
635 Py::new(slf.py(), iter)
636 }
637
638 pub fn __len__(&self) -> usize {
639 self.messages.len()
640 }
641
642 pub fn __getitem__(&self, index: isize) -> Result<GeminiContent, TypeError> {
643 let len = self.messages.len() as isize;
644 let normalized_index = if index < 0 { len + index } else { index };
645
646 if normalized_index < 0 || normalized_index >= len {
647 return Err(TypeError::Error(format!(
648 "Index {} out of range for list of length {}",
649 index, len
650 )));
651 }
652
653 Ok(self.messages[normalized_index as usize].clone())
654 }
655
656 pub fn __str__(&self) -> String {
657 PyHelperFuncs::__str__(&self.messages)
658 }
659
660 pub fn __repr__(&self) -> String {
661 self.__str__()
662 }
663}
664
665#[pyclass]
666pub struct GeminiContentIterator {
667 inner: std::vec::IntoIter<GeminiContent>,
668}
669
670#[pymethods]
671impl GeminiContentIterator {
672 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
673 slf
674 }
675
676 fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<GeminiContent> {
677 slf.inner.next()
678 }
679}