1use crate::anthropic::v1::request::MessageParam as AnthropicMessage;
2use crate::anthropic::v1::request::TextBlockParam;
3use crate::anthropic::v1::response::ResponseContentBlock;
4use crate::google::v1::generate::Candidate;
5use crate::google::v1::generate::GeminiContent;
6use crate::google::PredictResponse;
7use crate::openai::v1::chat::request::ChatMessage as OpenAIChatMessage;
8use crate::openai::v1::Choice;
9use crate::traits::MessageConversion;
10use crate::traits::PromptMessageExt;
11use crate::Provider;
12use crate::{StructuredOutput, TypeError};
13use potato_util::PyHelperFuncs;
14use pyo3::types::PyAnyMethods;
15use pyo3::{prelude::*, IntoPyObjectExt};
16use pythonize::{depythonize, pythonize};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::fmt::Display;
21use tracing::error;
/// Conversation role attached to a prompt message.
///
/// The lowercase name of each variant (e.g. `"user"`) is returned by `Role::as_str`.
/// NOTE: no `#[serde(rename_all)]` is applied, so serde (de)serializes these by their
/// variant names (`"User"`, `"Assistant"`, ...), not by the lowercase strings.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[pyclass]
pub enum Role {
    User,
    Assistant,
    Developer,
    Tool,
    Model,
    System,
}
32
33#[pymethods]
34impl Role {
35 #[pyo3(name = "as_str")]
36 pub fn as_str_py(&self) -> &'static str {
37 self.as_str()
38 }
39}
40
41impl Role {
42 pub const fn as_str(&self) -> &'static str {
44 match self {
45 Role::User => "user",
46 Role::Assistant => "assistant",
47 Role::Developer => "developer",
48 Role::Tool => "tool",
49 Role::Model => "model",
50 Role::System => "system",
51 }
52 }
53}
54
55impl Display for Role {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 Role::User => write!(f, "user"),
59 Role::Assistant => write!(f, "assistant"),
60 Role::Developer => write!(f, "developer"),
61 Role::Tool => write!(f, "tool"),
62 Role::Model => write!(f, "model"),
63 Role::System => write!(f, "system"),
64 }
65 }
66}
67
68impl From<Role> for &str {
69 fn from(role: Role) -> Self {
70 match role {
71 Role::User => "user",
72 Role::Assistant => "assistant",
73 Role::Developer => "developer",
74 Role::Tool => "tool",
75 Role::Model => "model",
76 Role::System => "system",
77 }
78 }
79}
80
81pub trait DeserializePromptValExt: for<'de> serde::Deserialize<'de> {
82 fn model_validate_json(value: &Value) -> Result<Self, serde_json::Error> {
90 serde_json::from_value(value.clone())
91 }
92}
93
94pub fn get_pydantic_module<'py>(py: Python<'py>, module_name: &str) -> PyResult<Bound<'py, PyAny>> {
95 py.import("pydantic_ai")?.getattr(module_name)
96}
97
98pub fn check_pydantic_model<'py>(
105 py: Python<'py>,
106 object: &Bound<'_, PyAny>,
107) -> Result<bool, TypeError> {
108 let pydantic = match py.import("pydantic").map_err(|e| {
110 error!("Failed to import pydantic: {}", e);
111 false
112 }) {
113 Ok(pydantic) => pydantic,
114 Err(_) => return Ok(false),
115 };
116
117 let is_subclass = py.import("builtins")?.getattr("issubclass")?;
119
120 let basemodel = pydantic.getattr("BaseModel")?;
122 let matched = is_subclass.call1((object, basemodel))?.extract::<bool>()?;
123
124 Ok(matched)
125}
126
127fn get_json_schema_from_basemodel(object: &Bound<'_, PyAny>) -> Result<Value, TypeError> {
133 let schema = object.getattr("model_json_schema")?.call1(())?;
135
136 let mut schema: Value = depythonize(&schema)?;
137
138 if let Some(additional_properties) = schema.get_mut("additionalProperties") {
140 *additional_properties = serde_json::json!(false);
141 } else {
142 schema
143 .as_object_mut()
144 .unwrap()
145 .insert("additionalProperties".to_string(), serde_json::json!(false));
146 }
147
148 Ok(schema)
149}
150
151fn parse_pydantic_model<'py>(
152 py: Python<'py>,
153 object: &Bound<'_, PyAny>,
154) -> Result<Option<Value>, TypeError> {
155 let is_subclass = check_pydantic_model(py, object)?;
156 if is_subclass {
157 Ok(Some(get_json_schema_from_basemodel(object)?))
158 } else {
159 Ok(None)
160 }
161}
162
163pub fn check_response_type(object: &Bound<'_, PyAny>) -> Result<Option<ResponseType>, TypeError> {
164 let response_type = match object.getattr("response_type") {
166 Ok(method) => {
167 if method.is_callable() {
168 let response_type: ResponseType = method.call0()?.extract()?;
169 Some(response_type)
170 } else {
171 None
172 }
173 }
174 Err(_) => None,
175 };
176
177 Ok(response_type)
178}
179
180fn get_json_schema_from_response_type(response_type: &ResponseType) -> Result<Value, TypeError> {
181 match response_type {
182 ResponseType::Score => Ok(Score::get_structured_output_schema()),
183 _ => {
184 Err(TypeError::Error(format!(
186 "Unsupported response type: {response_type}"
187 )))
188 }
189 }
190}
191
192pub fn parse_response_to_json<'py>(
193 py: Python<'py>,
194 object: &Bound<'_, PyAny>,
195) -> Result<(ResponseType, Option<Value>), TypeError> {
196 let is_pydantic_model = check_pydantic_model(py, object)?;
198 if is_pydantic_model {
199 return Ok((ResponseType::Pydantic, parse_pydantic_model(py, object)?));
200 }
201
202 let response_type = check_response_type(object)?;
204 if let Some(response_type) = response_type {
205 return Ok((
206 response_type.clone(),
207 Some(get_json_schema_from_response_type(&response_type)?),
208 ));
209 }
210
211 Ok((ResponseType::Null, None))
212}
213
/// Structured output made of an integer `score` and a textual `reason`.
///
/// Unknown JSON fields are rejected during deserialization (`deny_unknown_fields`).
/// The `schemars` range on `score` only constrains the generated JSON schema.
/// NOTE(review): serde deserialization does not itself enforce 1..=5.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct Score {
    /// Score value; the JSON schema advertises a range of 1 to 5.
    #[pyo3(get)]
    #[schemars(range(min = 1, max = 5))]
    pub score: i64,

    /// Free-text explanation accompanying the score.
    #[pyo3(get)]
    pub reason: String,
}
225#[pymethods]
226impl Score {
227 #[staticmethod]
228 pub fn response_type() -> ResponseType {
229 ResponseType::Score
230 }
231
232 #[staticmethod]
233 pub fn model_validate_json(json_string: String) -> Result<Score, TypeError> {
234 Ok(serde_json::from_str(&json_string)?)
235 }
236
237 #[staticmethod]
238 pub fn model_json_schema<'py>(py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
239 let schema = Score::get_structured_output_schema();
240 Ok(pythonize(py, &schema)?)
241 }
242
243 pub fn __str__(&self) -> String {
244 PyHelperFuncs::__str__(self)
245 }
246}
247
// `Score` relies entirely on `StructuredOutput`'s default methods (for example
// `get_structured_output_schema`, which `Score::model_json_schema` calls).
impl StructuredOutput for Score {}
249
/// Kind of structured response detected for a prompt's response format.
///
/// `Null` (the default) is what `parse_response_to_json` returns when the object is
/// neither a pydantic model nor exposes a callable `response_type()`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub enum ResponseType {
    /// The built-in `Score` structured output.
    Score,
    /// A user-supplied pydantic `BaseModel` subclass.
    Pydantic,
    /// No structured output.
    #[default]
    Null,
}
258
259impl Display for ResponseType {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 match self {
262 ResponseType::Score => write!(f, "Score"),
263 ResponseType::Pydantic => write!(f, "Pydantic"),
264 ResponseType::Null => write!(f, "Null"),
265 }
266 }
267}
268
/// A prompt message stored in one provider's native format.
///
/// `#[serde(untagged)]`: when deserializing, serde tries the variants in declaration
/// order and keeps the first that succeeds, so the order below is significant.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageNum {
    /// OpenAI chat message.
    OpenAIMessageV1(OpenAIChatMessage),
    /// Anthropic message.
    AnthropicMessageV1(AnthropicMessage),
    /// Gemini content entry (native to the Google, Vertex and Gemini providers).
    GeminiContentV1(GeminiContent),

    /// Anthropic system prompt held as a bare text block, as produced by
    /// `anthropic_message_to_system_message`.
    AnthropicSystemMessageV1(TextBlockParam),
}
281
282impl MessageNum {
283 fn matches_provider(&self, provider: &Provider) -> bool {
285 matches!(
286 (self, provider),
287 (MessageNum::OpenAIMessageV1(_), Provider::OpenAI)
288 | (MessageNum::AnthropicMessageV1(_), Provider::Anthropic)
289 | (MessageNum::AnthropicSystemMessageV1(_), Provider::Anthropic)
290 | (MessageNum::GeminiContentV1(_), Provider::Google)
291 | (MessageNum::GeminiContentV1(_), Provider::Vertex)
292 | (MessageNum::GeminiContentV1(_), Provider::Gemini)
293 )
294 }
295
296 fn to_openai_message(&self) -> Result<MessageNum, TypeError> {
301 match self {
302 MessageNum::AnthropicMessageV1(msg) => {
303 Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
304 }
305 MessageNum::GeminiContentV1(msg) => {
306 Ok(MessageNum::OpenAIMessageV1(msg.to_openai_message()?))
307 }
308 _ => Err(TypeError::CantConvertSelf),
309 }
310 }
311
312 fn to_anthropic_message(&self) -> Result<MessageNum, TypeError> {
314 match self {
315 MessageNum::OpenAIMessageV1(msg) => {
316 Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
317 }
318 MessageNum::GeminiContentV1(msg) => {
319 Ok(MessageNum::AnthropicMessageV1(msg.to_anthropic_message()?))
320 }
321 _ => Err(TypeError::CantConvertSelf),
322 }
323 }
324
325 fn to_google_message(&self) -> Result<MessageNum, TypeError> {
327 match self {
328 MessageNum::OpenAIMessageV1(msg) => {
329 Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
330 }
331 MessageNum::AnthropicMessageV1(msg) => {
332 Ok(MessageNum::GeminiContentV1(msg.to_google_message()?))
333 }
334 _ => Err(TypeError::CantConvertSelf),
335 }
336 }
337
338 fn convert_message_to_provider_type(
339 &self,
340 provider: &Provider,
341 ) -> Result<MessageNum, TypeError> {
342 match provider {
343 Provider::OpenAI => self.to_openai_message(),
344 Provider::Anthropic => self.to_anthropic_message(),
345 Provider::Google => self.to_google_message(),
346 Provider::Vertex => self.to_google_message(),
347 Provider::Gemini => self.to_google_message(),
348 _ => Err(TypeError::UnsupportedProviderError),
349 }
350 }
351
352 pub fn convert_message(&mut self, provider: &Provider) -> Result<(), TypeError> {
353 if self.matches_provider(provider) {
355 return Ok(());
356 }
357 let converted = self.convert_message_to_provider_type(provider)?;
358 *self = converted;
359 Ok(())
360 }
361
362 pub fn anthropic_message_to_system_message(&mut self) -> Result<(), TypeError> {
363 match self {
364 MessageNum::AnthropicMessageV1(msg) => {
365 let text_param = msg.to_text_block_param()?;
366 *self = MessageNum::AnthropicSystemMessageV1(text_param);
367 Ok(())
368 }
369 _ => Err(TypeError::Error(
370 "Cannot convert non-AnthropicMessageV1 to system message".to_string(),
371 )),
372 }
373 }
374 pub fn role(&self) -> &str {
375 match self {
376 MessageNum::OpenAIMessageV1(msg) => &msg.role,
377 MessageNum::AnthropicMessageV1(msg) => &msg.role,
378 MessageNum::GeminiContentV1(msg) => &msg.role,
379 _ => "system",
380 }
381 }
382 pub fn bind(&self, name: &str, value: &str) -> Result<Self, TypeError> {
383 match self {
384 MessageNum::OpenAIMessageV1(msg) => {
385 let bound_msg = msg.bind(name, value)?;
386 Ok(MessageNum::OpenAIMessageV1(bound_msg))
387 }
388 MessageNum::AnthropicMessageV1(msg) => {
389 let bound_msg = msg.bind(name, value)?;
390 Ok(MessageNum::AnthropicMessageV1(bound_msg))
391 }
392 MessageNum::GeminiContentV1(msg) => {
393 let bound_msg = msg.bind(name, value)?;
394 Ok(MessageNum::GeminiContentV1(bound_msg))
395 }
396 _ => Ok(self.clone()),
397 }
398 }
399 pub fn bind_mut(&mut self, name: &str, value: &str) -> Result<(), TypeError> {
400 match self {
401 MessageNum::OpenAIMessageV1(msg) => msg.bind_mut(name, value),
402 MessageNum::AnthropicMessageV1(msg) => msg.bind_mut(name, value),
403 MessageNum::GeminiContentV1(msg) => msg.bind_mut(name, value),
404 _ => Ok(()),
405 }
406 }
407
408 pub(crate) fn extract_variables(&self) -> Vec<String> {
409 match self {
410 MessageNum::OpenAIMessageV1(msg) => msg.extract_variables(),
411 MessageNum::AnthropicMessageV1(msg) => msg.extract_variables(),
412 MessageNum::GeminiContentV1(msg) => msg.extract_variables(),
413 _ => vec![],
414 }
415 }
416
417 pub fn to_bound_py_object<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
418 match self {
419 MessageNum::OpenAIMessageV1(msg) => {
420 let bound_msg = msg.clone().into_bound_py_any(py)?;
421 Ok(bound_msg)
422 }
423 MessageNum::AnthropicMessageV1(msg) => {
424 let bound_msg = msg.clone().into_bound_py_any(py)?;
425 Ok(bound_msg)
426 }
427 MessageNum::GeminiContentV1(msg) => {
428 let bound_msg = msg.clone().into_bound_py_any(py)?;
429 Ok(bound_msg)
430 }
431 MessageNum::AnthropicSystemMessageV1(msg) => {
432 let anthropic_msg = AnthropicMessage::from_text(msg.text.clone(), self.role())?;
434 let bound_msg = anthropic_msg.into_bound_py_any(py)?;
435 Ok(bound_msg)
436 }
437 }
438 }
439
440 pub fn to_bound_openai_message<'py>(
441 &self,
442 py: Python<'py>,
443 ) -> Result<Bound<'py, OpenAIChatMessage>, TypeError> {
444 match self {
445 MessageNum::OpenAIMessageV1(msg) => {
446 let py_obj = Py::new(py, msg.clone())?;
447 let bound = py_obj.bind(py);
448 Ok(bound.clone())
449 }
450 _ => Err(TypeError::CantConvertSelf),
451 }
452 }
453
454 pub fn to_bound_gemini_message<'py>(
455 &self,
456 py: Python<'py>,
457 ) -> Result<Bound<'py, GeminiContent>, TypeError> {
458 match self {
459 MessageNum::GeminiContentV1(msg) => {
460 let py_obj = Py::new(py, msg.clone())?;
461 let bound = py_obj.bind(py);
462 Ok(bound.clone())
463 }
464 _ => Err(TypeError::CantConvertSelf),
465 }
466 }
467
468 pub fn to_bound_anthropic_message<'py>(
469 &self,
470 py: Python<'py>,
471 ) -> Result<Bound<'py, AnthropicMessage>, TypeError> {
472 match self {
473 MessageNum::AnthropicMessageV1(msg) => {
474 let py_obj = Py::new(py, msg.clone())?;
475 let bound = py_obj.bind(py);
476 Ok(bound.clone())
477 }
478 _ => Err(TypeError::CantConvertSelf),
479 }
480 }
481
482 pub fn is_system_message(&self) -> bool {
483 match self {
484 MessageNum::OpenAIMessageV1(msg) => {
485 msg.role == Role::Developer.to_string() || msg.role == Role::System.to_string()
486 }
487 MessageNum::AnthropicMessageV1(msg) => msg.role == Role::System.to_string(),
488 MessageNum::GeminiContentV1(msg) => msg.role == Role::Model.to_string(),
489 MessageNum::AnthropicSystemMessageV1(_) => true,
490 }
491 }
492
493 pub fn is_user_message(&self) -> bool {
494 match self {
495 MessageNum::OpenAIMessageV1(msg) => msg.role == Role::User.to_string(),
496 MessageNum::AnthropicMessageV1(msg) => msg.role == Role::User.to_string(),
497 MessageNum::GeminiContentV1(msg) => msg.role == Role::User.to_string(),
498 MessageNum::AnthropicSystemMessageV1(_) => false,
499 }
500 }
501}
502
/// One response item, tagged by the provider type that produced it.
#[derive(Debug, Clone)]
pub enum ResponseContent {
    /// An OpenAI `Choice`.
    OpenAI(Choice),
    /// A Google `Candidate`.
    Google(Candidate),
    /// An Anthropic response content block.
    Anthropic(ResponseContentBlock),
    /// A `PredictResponse` from the `google` module.
    PredictResponse(PredictResponse),
}
510
/// Python sequence wrapper (`len`, indexing, iteration) over OpenAI chat messages.
#[pyclass]
pub struct OpenAIMessageList {
    pub messages: Vec<OpenAIChatMessage>,
}
515
516#[pymethods]
517impl OpenAIMessageList {
518 fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<OpenAIMessageIterator>> {
519 let iter = OpenAIMessageIterator {
520 inner: slf.messages.clone().into_iter(),
521 };
522 Py::new(slf.py(), iter)
523 }
524
525 pub fn __len__(&self) -> usize {
526 self.messages.len()
527 }
528
529 pub fn __getitem__(&self, index: isize) -> Result<OpenAIChatMessage, TypeError> {
530 let len = self.messages.len() as isize;
531 let normalized_index = if index < 0 { len + index } else { index };
532
533 if normalized_index < 0 || normalized_index >= len {
534 return Err(TypeError::Error(format!(
535 "Index {} out of range for list of length {}",
536 index, len
537 )));
538 }
539
540 Ok(self.messages[normalized_index as usize].clone())
541 }
542
543 pub fn __str__(&self) -> String {
544 PyHelperFuncs::__str__(&self.messages)
545 }
546
547 pub fn __repr__(&self) -> String {
548 self.__str__()
549 }
550}
551
/// Python iterator over an owned copy of OpenAI messages
/// (built by `OpenAIMessageList.__iter__`).
#[pyclass]
pub struct OpenAIMessageIterator {
    inner: std::vec::IntoIter<OpenAIChatMessage>,
}
556
557#[pymethods]
558impl OpenAIMessageIterator {
559 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
560 slf
561 }
562
563 fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<OpenAIChatMessage> {
564 slf.inner.next()
565 }
566}
567
/// Python sequence wrapper (`len`, indexing, iteration) over Anthropic messages.
#[pyclass]
pub struct AnthropicMessageList {
    pub messages: Vec<AnthropicMessage>,
}
572
573#[pymethods]
574impl AnthropicMessageList {
575 fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<AnthropicMessageIterator>> {
576 let iter = AnthropicMessageIterator {
577 inner: slf.messages.clone().into_iter(),
578 };
579 Py::new(slf.py(), iter)
580 }
581
582 pub fn __len__(&self) -> usize {
583 self.messages.len()
584 }
585
586 pub fn __getitem__(&self, index: isize) -> Result<AnthropicMessage, TypeError> {
587 let len = self.messages.len() as isize;
588 let normalized_index = if index < 0 { len + index } else { index };
589
590 if normalized_index < 0 || normalized_index >= len {
591 return Err(TypeError::Error(format!(
592 "Index {} out of range for list of length {}",
593 index, len
594 )));
595 }
596
597 Ok(self.messages[normalized_index as usize].clone())
598 }
599
600 pub fn __str__(&self) -> String {
601 PyHelperFuncs::__str__(&self.messages)
602 }
603
604 pub fn __repr__(&self) -> String {
605 self.__str__()
606 }
607}
608
/// Python iterator over an owned copy of Anthropic messages
/// (built by `AnthropicMessageList.__iter__`).
#[pyclass]
pub struct AnthropicMessageIterator {
    inner: std::vec::IntoIter<AnthropicMessage>,
}
613
614#[pymethods]
615impl AnthropicMessageIterator {
616 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
617 slf
618 }
619
620 fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<AnthropicMessage> {
621 slf.inner.next()
622 }
623}
624
/// Python sequence wrapper (`len`, indexing, iteration) over Gemini content entries.
#[pyclass]
pub struct GeminiContentList {
    pub messages: Vec<GeminiContent>,
}
629
630#[pymethods]
631impl GeminiContentList {
632 fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<GeminiContentIterator>> {
633 let iter = GeminiContentIterator {
634 inner: slf.messages.clone().into_iter(),
635 };
636 Py::new(slf.py(), iter)
637 }
638
639 pub fn __len__(&self) -> usize {
640 self.messages.len()
641 }
642
643 pub fn __getitem__(&self, index: isize) -> Result<GeminiContent, TypeError> {
644 let len = self.messages.len() as isize;
645 let normalized_index = if index < 0 { len + index } else { index };
646
647 if normalized_index < 0 || normalized_index >= len {
648 return Err(TypeError::Error(format!(
649 "Index {} out of range for list of length {}",
650 index, len
651 )));
652 }
653
654 Ok(self.messages[normalized_index as usize].clone())
655 }
656
657 pub fn __str__(&self) -> String {
658 PyHelperFuncs::__str__(&self.messages)
659 }
660
661 pub fn __repr__(&self) -> String {
662 self.__str__()
663 }
664}
665
/// Python iterator over an owned copy of Gemini content entries
/// (built by `GeminiContentList.__iter__`).
#[pyclass]
pub struct GeminiContentIterator {
    inner: std::vec::IntoIter<GeminiContent>,
}
670
671#[pymethods]
672impl GeminiContentIterator {
673 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
674 slf
675 }
676
677 fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<GeminiContent> {
678 slf.inner.next()
679 }
680}