potato_type/traits.rs
1use crate::{
2 error::TypeError,
3 prompt::{MessageNum, ModelSettings, ResponseContent},
4 tools::AgentToolDefinition,
5 Provider,
6};
7use potato_util::utils::TokenLogProbs;
8use pyo3::prelude::*;
9use pyo3::types::PyList;
10use regex::Regex;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::sync::OnceLock;
14
15pub static VAR_REGEX: OnceLock<Regex> = OnceLock::new();
16pub fn get_var_regex() -> &'static Regex {
17 VAR_REGEX.get_or_init(|| Regex::new(r"\$?\{([a-zA-Z_][a-zA-Z0-9_]*)\}").unwrap())
18}
19use crate::prompt::builder::ProviderRequest;
20/// Core trait that all message types must implement
21pub trait PromptMessageExt:
22 Send + Sync + Clone + Serialize + for<'de> Deserialize<'de> + PartialEq
23{
24 /// Bind a variable in the message content, returning a new instance
25 fn bind(&self, name: &str, value: &str) -> Result<Self, TypeError>
26 where
27 Self: Sized;
28
29 /// Bind a variable in-place
30 fn bind_mut(&mut self, name: &str, value: &str) -> Result<(), TypeError>;
31
32 /// Extract variables from the message content
33 fn extract_variables(&self) -> Vec<String>;
34
35 fn from_text(content: String, role: &str) -> Result<Self, TypeError>;
36}
37
38/// Core trait that must be implemented for all request types
39pub trait RequestAdapter {
40 /// Returns all messages in the request
41 fn messages(&self) -> &[MessageNum];
42 /// Returns a mutable reference to the messages in the request
43 fn messages_mut(&mut self) -> &mut Vec<MessageNum>;
44 /// Returns all system instructions in the request
45 fn system_instructions(&self) -> Vec<&MessageNum>;
46 /// Returns the response JSON schema if set
47 fn response_json_schema(&self) -> Option<&Value>;
48 /// Inserts a message at the specified index (or at the start if None)
49 fn insert_message(&mut self, message: MessageNum, idx: Option<usize>) {
50 self.messages_mut().insert(idx.unwrap_or(0), message);
51 }
52 /// Prepends system instructions to the messages
53 fn preprend_system_instructions(&mut self, messages: Vec<MessageNum>) -> Result<(), TypeError>;
54
55 /// Returns the system instructions as a Python list
56 /// # Arguments
57 /// * `py` - The Python GIL token
58 /// # Returns
59 /// Returns a Python list of system instruction messages
60 fn get_py_system_instructions<'py>(
61 &self,
62 py: Python<'py>,
63 ) -> Result<Bound<'py, PyList>, TypeError>;
64 /// Returns the model settings for the request (python object)
65 fn model_settings<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError>;
66
67 /// Converts the request to a JSON value for sending to the provider
68 fn to_request_body(&self) -> Result<Value, TypeError>;
69 /// Checks if the request matches the given provider
70 fn match_provider(&self, provider: &Provider) -> bool;
71 /// Builds a provider-specific request enum from the given parameters
72 /// The ProviderRequest enum encapsulates all supported provider request types and is an
73 /// attribute of the Prompt struct. ProviderRequest is built on instantiation of the Prompt
74 fn build_provider_enum(
75 messages: Vec<MessageNum>,
76 system_instructions: Vec<MessageNum>,
77 model: String,
78 settings: ModelSettings,
79 response_json_schema: Option<Value>,
80 ) -> Result<ProviderRequest, TypeError>;
81
82 /// Sets the response JSON schema for the request
83 /// Typically used as part of workflows when adding tasks
84 fn set_response_json_schema(&mut self, response_json_schema: Option<Value>) -> ();
85
86 /// Adds tools to the request
87 fn add_tools(&mut self, tools: Vec<AgentToolDefinition>) -> Result<(), TypeError>;
88}
89
90pub trait ResponseAdapter {
91 /// Returns a string representation of the response
92 fn __str__(&self) -> String;
93
94 /// Checks if the response is empty
95 fn is_empty(&self) -> bool;
96
97 /// Converts the response to a Python object
98 fn to_bound_py_object<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError>;
99
100 /// Returns the response ID
101 fn id(&self) -> &str;
102
103 /// Converts the response to a vector of MessageNum
104 fn to_message_num(&self) -> Result<Vec<MessageNum>, TypeError>;
105
106 // Get the token usage as a Python object
107 fn usage<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError>;
108
109 /// Retrieves the first content choice from the response
110 fn get_content(&self) -> ResponseContent;
111
112 /// Retrieves the log probabilities from the response
113 fn get_log_probs(&self) -> Vec<TokenLogProbs>;
114
115 /// Returns the structured output of the response
116 /// For all response types the flow is as follows:
117 /// 1. Check if the response has content (string/text)
118 /// 2. If no content, return Python None
119 /// 3. If content exists, check if an output_type/model is provided
120 /// 4. If output_type/model is provided, attempt to convert the content to that type
121 /// 5. If conversion fails, attempt to construct a generic Python object from the content
122 /// 6. If no output_type/model is provided, return the content as a generic Python object
123 /// # Arguments
124 /// * `py`: The Python GIL token
125 /// * `output_type`: An optional Python type/model to convert the content into. This can be a pydantic model or any object
126 /// that implements model_validate_json that can parse from a JSON string.
127 /// # Returns
128 /// * `Result<Bound<'py, PyAny>, TypeError>`: The structured output as a Python object or an error
129 fn structured_output<'py>(
130 &self,
131 py: Python<'py>,
132 output_type: Option<&Bound<'py, PyAny>>,
133 ) -> Result<Bound<'py, PyAny>, TypeError>;
134
135 /// Returns the structured output value as a serde_json::Value
136 fn structured_output_value(&self) -> Option<Value>;
137
138 /// Returns any tool calls made in the response, if applicable
139 fn tool_call_output(&self) -> Option<Value>;
140
141 /// Returns the output text of the response if available
142 fn response_text(&self) -> String;
143}
144
145pub trait MessageResponseExt {
146 fn to_message_num(&self) -> Result<MessageNum, TypeError>;
147}
148
149pub trait MessageFactory: Sized {
150 fn from_text(content: String, role: &str) -> Result<Self, TypeError>;
151}
152
153/// Trait for converting between different provider message formats
154///
155/// This trait enables conversion of messages between different LLM provider formats
156/// (e.g., Anthropic MessageParam ↔ Google GeminiContent ↔ OpenAI ChatMessage).
157///
158/// Currently focused on text content conversion, with support for other content
159/// types planned for future implementation.
160pub trait MessageConversion {
161 /// Convert this message to an Anthropic MessageParam
162 ///
163 /// # Errors
164 /// Returns `TypeError::UnsupportedConversion` if the message contains
165 /// content types that cannot be represented in Anthropic's format
166 fn to_anthropic_message(
167 &self,
168 ) -> Result<crate::anthropic::v1::request::MessageParam, TypeError>;
169
170 /// Convert this message to a Google GeminiContent
171 ///
172 /// # Errors
173 /// Returns `TypeError::UnsupportedConversion` if the message contains
174 /// content types that cannot be represented in Google's format
175 fn to_google_message(
176 &self,
177 ) -> Result<crate::google::v1::generate::request::GeminiContent, TypeError>;
178
179 /// Convert this message to an OpenAI ChatMessage
180 ///
181 /// # Errors
182 /// Returns `TypeError::UnsupportedConversion` if the message contains
183 /// content types that cannot be represented in OpenAI's format
184 fn to_openai_message(&self)
185 -> Result<crate::openai::v1::chat::request::ChatMessage, TypeError>;
186}