//! potato_type/prompt/interface.rs — Prompt interface exposed to Python via pyo3.
1use crate::anthropic::v1::request::{AnthropicSettings, MessageParam as AnthropicMessage};
2use crate::error::TypeError;
3use crate::google::v1::generate::request::{GeminiContent, GeminiSettings};
4use crate::openai::v1::chat::request::ChatMessage as OpenAIChatMessage;
5use crate::openai::v1::chat::settings::OpenAIChatSettings;
6use crate::prompt::builder::{to_provider_request, ProviderRequest};
7use crate::prompt::settings::ModelSettings;
8use crate::prompt::types::parse_response_to_json;
9use crate::prompt::types::ResponseType;
10use crate::prompt::types::Role;
11use crate::prompt::{AnthropicMessageList, GeminiContentList, MessageNum, OpenAIMessageList};
12use crate::tools::AgentToolDefinition;
13use crate::traits::MessageFactory;
14use crate::SettingsType;
15use crate::{Provider, SaveName};
16use potato_util::utils::extract_string_value;
17use potato_util::PyHelperFuncs;
18use potatohead_macro::try_extract_message;
19use pyo3::prelude::*;
20use pyo3::types::{PyDict, PyList, PyString, PyTuple};
21use pythonize::pythonize;
22use serde::{Deserialize, Serialize};
23use serde_json::Value;
24use std::collections::HashSet;
25use std::path::PathBuf;
26
27fn create_message_for_provider(
28    content: String,
29    provider: &Provider,
30    role: &str,
31) -> Result<MessageNum, TypeError> {
32    match provider {
33        Provider::OpenAI => {
34            OpenAIChatMessage::from_text(content, role).map(MessageNum::OpenAIMessageV1)
35        }
36        Provider::Anthropic => {
37            AnthropicMessage::from_text(content, role).map(MessageNum::AnthropicMessageV1)
38        }
39        Provider::Gemini | Provider::Google | Provider::Vertex => {
40            GeminiContent::from_text(content, role).map(MessageNum::GeminiContentV1)
41        }
42        _ => Err(TypeError::Error(format!(
43            "Unsupported provider for message creation: {:?}",
44            provider
45        ))),
46    }
47}
48
/// Converts one Python object into a provider-tagged [`MessageNum`].
///
/// Accepts either a plain `str` — wrapped with `default_role` via the
/// provider-specific constructor — or an already-typed message object
/// (`OpenAIChatMessage`, `AnthropicMessage`, or `GeminiContent`).
///
/// # Errors
/// Returns `TypeError::InvalidMessageTypeInList` (carrying the Python type
/// name) when the object is none of the supported types.
fn parse_single_message(
    message: &Bound<'_, PyAny>,
    provider: &Provider,
    default_role: &str,
) -> Result<MessageNum, TypeError> {
    // String conversion (most common case)
    if message.is_instance_of::<PyString>() {
        let text = message.extract::<String>()?;
        return create_message_for_provider(text, provider, default_role);
    }

    // Try each message type using macro. NOTE(review): presumably the macro
    // returns early from this function on the first successful extraction —
    // defined in potatohead_macro; confirm there.
    try_extract_message!(
        message,
        OpenAIChatMessage => MessageNum::OpenAIMessageV1,
        AnthropicMessage => MessageNum::AnthropicMessageV1,
        GeminiContent => MessageNum::GeminiContentV1,
    );

    Err(TypeError::InvalidMessageTypeInList(
        message.get_type().name()?.to_string(),
    ))
}
72
/// Parses a Python object into a `Vec<MessageNum>`.
///
/// Accepts a single message (string or typed message object) or a list/tuple
/// of them; per-item rules are in [`parse_single_message`]. For Anthropic,
/// messages parsed under a system-like role (system/assistant/developer) are
/// converted in place to Anthropic's system-message (TextBlockParam) form.
fn parse_messages(
    messages: &Bound<'_, PyAny>,
    provider: &Provider,
    default_role: &str,
) -> Result<Vec<MessageNum>, TypeError> {
    // Single message
    let mut messages =
        if !messages.is_instance_of::<PyList>() && !messages.is_instance_of::<PyTuple>() {
            vec![parse_single_message(messages, provider, default_role)?]
        } else {
            // List/tuple of messages; first extraction error aborts the whole parse
            messages
                .try_iter()?
                .map(|item| {
                    let item = item?;
                    parse_single_message(&item, provider, default_role)
                })
                .collect::<Result<Vec<_>, _>>()?
        };

    // Convert Anthropic system messages to TextBlockParam format
    // optimize this later - maybe
    if provider == &Provider::Anthropic
        && (default_role == Role::System.as_str()
            || default_role == Role::Assistant.as_str()
            || default_role == Role::Developer.as_str())
    {
        for msg in messages.iter_mut() {
            msg.anthropic_message_to_system_message()?;
        }
    }

    Ok(messages)
}
107
108fn get_system_role(provider: &Provider) -> &'static str {
109    match provider {
110        Provider::OpenAI => Role::Developer.into(),
111        Provider::Gemini | Provider::Vertex | Provider::Google => Role::Model.into(),
112        Provider::Anthropic => Role::System.into(),
113        _ => Role::System.into(),
114    }
115}
116
117/// Helper for extracting system instructions from optional parameter
118pub fn extract_system_instructions(
119    system_instruction: Option<&Bound<'_, PyAny>>,
120    provider: &Provider,
121) -> Result<Option<Vec<MessageNum>>, TypeError> {
122    let system_instructions = if let Some(sys_inst) = system_instruction {
123        Some(parse_messages(
124            sys_inst,
125            provider,
126            get_system_role(provider),
127        )?)
128    } else {
129        None
130    };
131
132    Ok(system_instructions)
133}
134
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
/// A provider-agnostic prompt: parsed messages plus the model/provider
/// metadata needed to build a concrete request.
pub struct Prompt {
    // The fully-built provider request (messages, settings, response schema).
    pub request: ProviderRequest,

    // Model identifier, e.g. "gpt-4o".
    #[pyo3(get)]
    pub model: String,

    // Provider this prompt targets.
    #[pyo3(get)]
    pub provider: Provider,

    // Crate version recorded at construction time (set in `new_rs`).
    pub version: String,

    // Named template variables found in the messages; `serde(default)` keeps
    // older serialized prompts loadable (see `from_path`).
    #[pyo3(get)]
    #[serde(default)]
    pub parameters: Vec<String>,

    // How the response should be interpreted (structured output etc.).
    #[serde(default)]
    pub response_type: ResponseType,
}
155
156/// ModelSettings variant based on the type of settings provided.
157fn extract_model_settings(model_settings: &Bound<'_, PyAny>) -> Result<ModelSettings, TypeError> {
158    let settings_type = model_settings
159        .call_method0("settings_type")?
160        .extract::<SettingsType>()?;
161
162    match settings_type {
163        SettingsType::OpenAIChat => model_settings
164            .extract::<OpenAIChatSettings>()
165            .map(ModelSettings::OpenAIChat),
166        SettingsType::GoogleChat => model_settings
167            .extract::<GeminiSettings>()
168            .map(ModelSettings::GoogleChat),
169        SettingsType::Anthropic => model_settings
170            .extract::<AnthropicSettings>()
171            .map(ModelSettings::AnthropicChat),
172        SettingsType::ModelSettings => model_settings.extract::<ModelSettings>(),
173    }
174    .map_err(Into::into)
175}
176
177#[pymethods]
178impl Prompt {
179    /// Creates a new Prompt object.
180    /// Main parsing logic is as follows:
181    /// 1. Extract model settings if provided, otherwise use provider default settings.
182    /// 2. Message and system instructions are expected to be a variant of MessageNum (OpenAIChatMessage, AnthropicMessage or GeminiContent).
183    /// 3. On instantiation, message will be check if is_instance_of pystring. If pystring, provider will be used to map to appropriate message Text type
184    /// 4. If message is a pylist, each item will be checked for is_instance_of pystring or MessageNum variant and converted accordingly.
185    /// 5. If message is a single MessageNum variant, it will be extracted and wrapped in a vec.
186    /// 6. After messages are parsed, a full provider request struct will by built using to_provider_request function.
187    /// # Arguments:
188    /// * `message`: A single message or list of messages representing user input.
189    /// * `model`: The model identifier to use for the prompt.
190    /// * `provider`: The provider to use for the prompt.
191    /// * `system_instruction`: Optional system instruction message or list of messages.
192    /// * `model_settings`: Optional model settings to use for the prompt.
193    /// * `output_type`: Optional output type to enforce structured output.
194    #[new]
195    #[pyo3(signature = (messages, model, provider, system_instructions=None, model_settings=None, output_type=None))]
196    pub fn new(
197        py: Python<'_>,
198        messages: &Bound<'_, PyAny>,
199        model: &str,
200        provider: &Bound<'_, PyAny>,
201        system_instructions: Option<&Bound<'_, PyAny>>,
202        model_settings: Option<&Bound<'_, PyAny>>,
203        output_type: Option<&Bound<'_, PyAny>>, // can be a pydantic model or one of Opsml's predefined outputs
204    ) -> Result<Self, TypeError> {
205        // 1. get model settings if provided
206        let model_settings = model_settings
207            .as_ref()
208            .map(|s| extract_model_settings(s))
209            .transpose()?;
210
211        // 2. extract provider
212        let provider = Provider::extract_provider(provider)?;
213
214        // 3. Parse user messages with "user" role
215        // We'll use this to figure out the type of request struct to create
216        let messages = parse_messages(messages, &provider, Role::User.into())?;
217        let system_instructions = if let Some(sys_inst) = system_instructions {
218            parse_messages(sys_inst, &provider, get_system_role(&provider))?
219        } else {
220            vec![]
221        };
222
223        // 4.  validate response_json_schema
224        let (response_type, response_json_schema) = match output_type {
225            Some(output_type) => {
226                // check if output_type is a pydantic model and extract the model json schema
227                parse_response_to_json(py, output_type)?
228            }
229            None => (ResponseType::Null, None),
230        };
231
232        Self::new_rs(
233            messages,
234            model,
235            provider,
236            system_instructions,
237            model_settings,
238            response_json_schema,
239            response_type,
240        )
241    }
242
    #[getter]
    /// Returns the model settings as a Python object (delegates to the
    /// underlying provider request).
    pub fn model_settings<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
        self.request.model_settings(py)
    }
247
248    #[getter]
249    pub fn model_identifier(&self) -> String {
250        format!("{}:{}", self.provider.as_str(), self.model)
251    }
252
    #[pyo3(signature = (path = None))]
    /// Saves the prompt as JSON to `path`, or to the default prompt save name
    /// when no path is given. Returns the path written to.
    pub fn save_prompt(&self, path: Option<PathBuf>) -> PyResult<PathBuf> {
        let save_path = path.unwrap_or_else(|| PathBuf::from(SaveName::Prompt));
        PyHelperFuncs::save_to_json(self, &save_path)?;
        Ok(save_path)
    }
259
260    #[staticmethod]
261    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
262        let content = std::fs::read_to_string(&path)?;
263
264        let extension = path
265            .extension()
266            .and_then(|ext| ext.to_str())
267            .ok_or_else(|| TypeError::Error(format!("Invalid file path: {:?}", path)))?;
268
269        let mut prompt = match extension.to_lowercase().as_str() {
270            "json" => {
271                let json_value: Prompt = serde_json::from_str(&content)?;
272                Ok(json_value)
273            }
274            "yaml" | "yml" => {
275                let yaml_value: Prompt = serde_yaml::from_str(&content)?;
276                Ok(yaml_value)
277            }
278            _ => Err(TypeError::Error(format!(
279                "Unsupported file extension '{}'. Expected .json, .yaml, or .yml",
280                extension
281            ))),
282        }?;
283
284        if prompt.parameters.is_empty() {
285            // if paramters is empty, extract from messages
286            let system_instructions: Vec<MessageNum> = prompt
287                .request
288                .system_instructions()
289                .iter()
290                .map(|msg| (*msg).clone())
291                .collect();
292            let parameters =
293                Self::extract_variables(prompt.request.messages(), &system_instructions);
294            prompt.parameters = parameters;
295        }
296
297        Ok(prompt)
298    }
299
300    #[staticmethod]
301    pub fn model_validate_json(json_string: String) -> Result<Self, TypeError> {
302        let json_value: Value = serde_json::from_str(&json_string)?;
303        let model: Self = serde_json::from_value(json_value)?;
304
305        Ok(model)
306    }
307
308    pub fn model_dump_json(&self) -> String {
309        serde_json::to_string(self).unwrap()
310    }
311
    /// Python `__str__`: human-readable representation via the shared helper.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
315
    #[getter]
    /// Returns all messages as Python objects, including system instructions,
    /// user messages, and assistant messages. Delegates to the provider request.
    pub fn all_messages<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyList>, TypeError> {
        self.request.get_all_py_messages(py)
    }
321
    #[getter]
    /// Returns User messages as Python objects. This means system
    /// instructions are excluded.
    pub fn messages<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyList>, TypeError> {
        self.request.get_py_messages(py)
    }
327
    #[getter]
    /// Returns the last User message as a Python object. This means system
    /// instructions are excluded.
    pub fn message<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
        self.request.get_py_message(py)
    }
333
334    #[getter]
335    /// Returns the messages as OpenAI ChatMessage Python objects
336    /// This is a helper that provide strict typing when working with OpenAI prompts
337    pub fn openai_messages(&self) -> Result<OpenAIMessageList, TypeError> {
338        if self.provider != Provider::OpenAI {
339            return Err(TypeError::Error(
340                "Prompt provider is not OpenAI".to_string(),
341            ));
342        }
343        let messages = self
344            .request
345            .messages()
346            .iter()
347            .filter(|msg| msg.is_user_message())
348            .filter_map(|msg| match msg {
349                MessageNum::OpenAIMessageV1(m) => Some(m.clone()),
350                _ => None,
351            })
352            .collect::<Vec<_>>();
353        Ok(OpenAIMessageList { messages })
354    }
355
356    #[getter]
357    /// Returns the last message as an OpenAI ChatMessage Python object
358    /// This is a helper that provide strict typing when working with OpenAI prompts
359    pub fn openai_message(&self) -> Result<OpenAIChatMessage, TypeError> {
360        if self.provider != Provider::OpenAI {
361            return Err(TypeError::Error(
362                "Prompt provider is not OpenAI".to_string(),
363            ));
364        }
365        self.request.get_openai_message()
366    }
367
368    #[getter]
369    /// Returns the messages as Google GeminiContent Python objects
370    /// This is a helper that provide strict typing when working with Google/Gemini/Vertex prompts
371    pub fn gemini_messages(&self) -> Result<GeminiContentList, TypeError> {
372        if !self.is_google_provider() {
373            return Err(TypeError::Error(
374                "Prompt provider is not Google, Gemini, or Vertex".to_string(),
375            ));
376        }
377        let messages = self
378            .request
379            .messages()
380            .iter()
381            .filter(|msg| msg.is_user_message())
382            .filter_map(|msg| match msg {
383                MessageNum::GeminiContentV1(m) => Some(m.clone()),
384                _ => None,
385            })
386            .collect::<Vec<_>>();
387
388        Ok(GeminiContentList { messages })
389    }
390
391    #[getter]
392    /// Returns the last message as a Google GeminiContent Python object
393    /// This is a helper that provide strict typing when working with Google/Gemini/Vertex prompts
394    pub fn gemini_message(&self) -> Result<GeminiContent, TypeError> {
395        if !self.is_google_provider() {
396            return Err(TypeError::Error(
397                "Prompt provider is not Google, Gemini, or Vertex".to_string(),
398            ));
399        }
400        self.request.get_gemini_message()
401    }
402
403    #[getter]
404    pub fn anthropic_messages(&self) -> Result<AnthropicMessageList, TypeError> {
405        if self.provider != Provider::Anthropic {
406            return Err(TypeError::Error(
407                "Prompt provider is not Anthropic".to_string(),
408            ));
409        }
410        let messages = self
411            .request
412            .messages()
413            .iter()
414            .filter(|msg| msg.is_user_message())
415            .filter_map(|msg| match msg {
416                MessageNum::AnthropicMessageV1(m) => Some(m.clone()),
417                _ => None,
418            })
419            .collect::<Vec<_>>();
420
421        Ok(AnthropicMessageList { messages })
422    }
423
424    #[getter]
425    /// Returns the last message as an Anthropic MessageParam Python object
426    pub fn anthropic_message(&self) -> Result<AnthropicMessage, TypeError> {
427        if self.provider != Provider::Anthropic {
428            return Err(TypeError::Error(
429                "Prompt provider is not Anthropic".to_string(),
430            ));
431        }
432        self.request.get_anthropic_message()
433    }
434
    #[getter]
    /// Returns the system instruction messages as Python objects.
    pub fn system_instructions<'py>(
        &self,
        py: Python<'py>,
    ) -> Result<Bound<'py, PyList>, TypeError> {
        self.request.get_py_system_instructions(py)
    }
442
443    /// Binds a variable in the prompt to a value. This will return a new Prompt with the variable bound to the value.
444    /// This will iterate over all user messages and bind the variable in each message.
445    /// # Arguments:
446    /// * `name`: The name of the variable to bind.
447    /// * `value`: The value to bind the variable to.
448    /// # Returns:
449    /// * `Result<Self, PromptError>`: Returns a new Prompt with the variable bound to the value.
450    #[pyo3(signature = (name=None, value=None, **kwargs))]
451    pub fn bind(
452        &self,
453        name: Option<&str>,
454        value: Option<&Bound<'_, PyAny>>,
455        kwargs: Option<&Bound<'_, PyDict>>,
456    ) -> Result<Self, TypeError> {
457        let mut new_prompt = self.clone();
458
459        if let (Some(name), Some(value)) = (name, value) {
460            let var_value = extract_string_value(value)?;
461            for message in new_prompt.request.messages_mut() {
462                message.bind_mut(name, &var_value)?;
463            }
464        }
465
466        if let Some(kwargs) = kwargs {
467            for (key, val) in kwargs.iter() {
468                let var_name = key.extract::<String>()?;
469                let var_value = extract_string_value(&val)?;
470
471                for message in new_prompt.request.messages_mut() {
472                    message.bind_mut(&var_name, &var_value)?;
473                }
474            }
475        }
476
477        if name.is_none() && kwargs.is_none_or(|k| k.is_empty()) {
478            return Err(TypeError::Error(
479                "Must provide either (name, value) or keyword arguments for binding".to_string(),
480            ));
481        }
482
483        Ok(new_prompt)
484    }
485
486    /// Binds a variable in the prompt to a value. This will mutate the current Prompt and bind the variable in each user message.
487    /// # Arguments:
488    /// * `name`: The name of the variable to bind.
489    /// * `value`: The value to bind the variable to.
490    /// # Returns:
491    /// * `Result<(), PromptError>`: Returns Ok(()) on success or an error if the binding fails.
492    #[pyo3(signature = (name=None, value=None, **kwargs))]
493    pub fn bind_mut(
494        &mut self,
495        name: Option<&str>,
496        value: Option<&Bound<'_, PyAny>>,
497        kwargs: Option<&Bound<'_, PyDict>>,
498    ) -> Result<(), TypeError> {
499        if let (Some(name), Some(value)) = (name, value) {
500            let var_value = extract_string_value(value)?;
501            for message in self.request.messages_mut() {
502                message.bind_mut(name, &var_value)?;
503            }
504        }
505
506        if let Some(kwargs) = kwargs {
507            for (key, val) in kwargs.iter() {
508                let var_name = key.extract::<String>()?;
509                let var_value = extract_string_value(&val)?;
510
511                for message in self.request.messages_mut() {
512                    message.bind_mut(&var_name, &var_value)?;
513                }
514            }
515        }
516
517        if name.is_none() && kwargs.is_none_or(|k| k.is_empty()) {
518            return Err(TypeError::Error(
519                "Must provide either (name, value) or keyword arguments for binding".to_string(),
520            ));
521        }
522
523        Ok(())
524    }
525
    #[getter]
    /// Returns the response JSON schema pretty-printed, or `None` when no
    /// schema is set.
    pub fn response_json_schema_pretty(&self) -> Option<String> {
        Some(PyHelperFuncs::__str__(
            self.request.response_json_schema().as_ref()?,
        ))
    }
532
    #[getter]
    #[pyo3(name = "response_json_schema")]
    /// Returns the response JSON schema as a compact string, or `None` when
    /// no schema is set.
    pub fn response_json_schema_py(&self) -> Option<String> {
        Some(self.request.response_json_schema().as_ref()?.to_string())
    }
538
    /// Returns the provider request serialized as a Python object
    /// (pydantic-style `model_dump`).
    pub fn model_dump<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
        let request = &self.request.to_json()?;
        Ok(pythonize(py, request)?)
    }
543}
544
545impl Prompt {
    /// Rust-side accessor for the response JSON schema, if any.
    pub fn response_json_schema(&self) -> Option<&Value> {
        self.request.response_json_schema()
    }
549
    /// Rust-native constructor shared by the Python `new` and internal callers.
    ///
    /// Validates (or defaults) the model settings, extracts template
    /// parameters from the messages, and builds the provider request.
    pub fn new_rs(
        messages: Vec<MessageNum>,
        model: &str,
        provider: Provider,
        system_instructions: Vec<MessageNum>,
        model_settings: Option<ModelSettings>,
        response_json_schema: Option<Value>,
        response_type: ResponseType,
    ) -> Result<Self, TypeError> {
        let model = model.to_string();
        // get version from crate
        let version = potato_util::version();
        // Use the provided settings (validated against the provider) or fall
        // back to the provider's default settings.
        let model_settings = match model_settings {
            Some(settings) => {
                // validates if provider and settings are compatible
                settings.validate_provider(&provider)?;
                settings
            }
            None => ModelSettings::provider_default_settings(&provider),
        };

        // extract named parameters in prompt
        let parameters = Self::extract_variables(&messages, &system_instructions);

        // Build the provider request
        let request = to_provider_request(
            messages,
            system_instructions,
            model.clone(),
            model_settings,
            response_json_schema,
        )?;

        Ok(Self {
            request,
            version,
            parameters,
            response_type,
            model,
            provider,
        })
    }
593
594    fn is_google_provider(&self) -> bool {
595        matches!(
596            self.provider,
597            Provider::Google | Provider::Gemini | Provider::Vertex
598        )
599    }
600
    /// Adds tool definitions to the underlying provider request.
    pub fn add_tools(&mut self, tools: Vec<AgentToolDefinition>) -> Result<(), TypeError> {
        self.request.add_tools(tools)
    }
604
605    pub fn extract_variables(
606        messages: &[MessageNum],
607        system_instructions: &[MessageNum],
608    ) -> Vec<String> {
609        let mut variables = HashSet::new();
610
611        // Extract from system instructions
612        for msg in system_instructions {
613            variables.extend(msg.extract_variables());
614        }
615
616        // Extract from user messages
617        for msg in messages {
618            variables.extend(msg.extract_variables());
619        }
620
621        variables.into_iter().collect()
622    }
623
    /// Converts the Prompt to a JSON Value, falling back to `Value::Null` if
    /// serialization fails.
    pub fn model_dump_value(&self) -> Value {
        // Convert the Prompt to a JSON Value
        serde_json::to_value(self).unwrap_or(Value::Null)
    }
628
629    pub fn to_request_json(&self) -> Result<Value, TypeError> {
630        // Convert the Prompt to a JSON Value
631        let json_value = serde_json::to_value(self)?;
632
633        Ok(json_value)
634    }
635
    /// Replaces the response JSON schema on the request and records the new
    /// response type.
    pub fn set_response_json_schema(
        &mut self,
        response_json_schema: Option<Value>,
        response_type: ResponseType,
    ) {
        self.request.set_response_json_schema(response_json_schema);
        self.response_type = response_type;
    }
644}
645
646// tests
647#[cfg(test)]
648mod tests {
649    use super::*;
650    use crate::anthropic::v1::request::{
651        Base64ImageSource, Base64PDFSource, ContentBlockParam, DocumentBlockParam, ImageBlockParam,
652        MessageParam, PlainTextSource, TextBlockParam, UrlImageSource, UrlPDFSource,
653    };
654    use crate::google::{DataNum, GeminiContent, Part};
655    use crate::openai::v1::chat::request::{
656        ChatMessage as OpenAIChatMessage, ContentPart, FileContentPart, ImageContentPart,
657        TextContentPart,
658    };
659    use crate::prompt::types::Score;
660    use crate::StructuredOutput;
661
    /// Builds a simple OpenAI user message with a single text part.
    fn create_openai_chat_message() -> OpenAIChatMessage {
        let text_part = TextContentPart::new("What company is this logo from?".to_string());
        let text_content_part = ContentPart::Text(text_part);
        OpenAIChatMessage {
            role: "user".to_string(),
            content: vec![text_content_part],
            name: None,
        }
    }
671
    /// Builds an OpenAI system-style message using the "developer" role.
    fn create_system_openai_chat_message() -> OpenAIChatMessage {
        let text_part = TextContentPart::new("system_prompt".to_string());
        let text_content_part = ContentPart::Text(text_part);
        OpenAIChatMessage {
            role: "developer".to_string(),
            content: vec![text_content_part],
            name: None,
        }
    }
681
    /// Builds an OpenAI user message carrying a single image-URL part.
    fn create_openai_image_message() -> OpenAIChatMessage {
        let image_part = ImageContentPart::new("https://iili.io/3Hs4FMg.png".to_string(), None);
        let image_content_part = ContentPart::ImageUrl(image_part);
        OpenAIChatMessage {
            role: "user".to_string(),
            content: vec![image_content_part],
            name: None,
        }
    }
691
    /// Builds an OpenAI user message carrying a single file-content part.
    fn create_openai_file_message() -> OpenAIChatMessage {
        let file_part = FileContentPart::new(
            Some("filedata".to_string()),
            Some("fileid".to_string()),
            Some("filename".to_string()),
        );
        let file_content_part = ContentPart::FileContent(file_part);
        OpenAIChatMessage {
            role: "user".to_string(),
            content: vec![file_content_part],
            name: None,
        }
    }
705
    /// Builds an Anthropic user message with a single text block.
    fn create_anthropic_text_message() -> MessageParam {
        let text_block =
            TextBlockParam::new_rs("What company is this logo from?".to_string(), None, None);
        MessageParam {
            role: "user".to_string(),
            content: vec![ContentBlockParam {
                inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
            }],
        }
    }
716
    /// Builds an Anthropic "assistant"-role text message used as a system prompt.
    fn create_anthropic_system_message() -> MessageParam {
        let text_block = TextBlockParam::new_rs("system_prompt".to_string(), None, None);
        MessageParam {
            role: "assistant".to_string(),
            content: vec![ContentBlockParam {
                inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
            }],
        }
    }
726
    /// Builds an Anthropic user message with a base64-encoded image block.
    fn create_anthropic_base64_image_message() -> MessageParam {
        let image_source =
            Base64ImageSource::new("image/png".to_string(), "base64data".to_string()).unwrap();
        let image_block = ImageBlockParam {
            source: crate::anthropic::v1::request::ImageSource::Base64(image_source),
            cache_control: None,
            r#type: "image".to_string(),
        };
        MessageParam {
            role: "user".to_string(),
            content: vec![ContentBlockParam {
                inner: crate::anthropic::v1::request::ContentBlock::Image(image_block),
            }],
        }
    }
742
    /// Builds an Anthropic user message with a URL-sourced image block.
    fn create_anthropic_url_image_message() -> MessageParam {
        let image_source = UrlImageSource::new("https://iili.io/3Hs4FMg.png".to_string());
        let image_block = ImageBlockParam {
            source: crate::anthropic::v1::request::ImageSource::Url(image_source),
            cache_control: None,
            r#type: "image".to_string(),
        };
        MessageParam {
            role: "user".to_string(),
            content: vec![ContentBlockParam {
                inner: crate::anthropic::v1::request::ContentBlock::Image(image_block),
            }],
        }
    }
757
    /// Builds an Anthropic user message with a base64-encoded PDF document block.
    fn create_anthropic_base64_pdf_message() -> MessageParam {
        let pdf_source = Base64PDFSource::new("base64pdfdata".to_string()).unwrap();
        let document_block = DocumentBlockParam {
            source: crate::anthropic::v1::request::DocumentSource::Base64(pdf_source),
            cache_control: None,
            title: Some("test_document.pdf".to_string()),
            context: None,
            r#type: "document".to_string(),
            citations: None,
        };
        MessageParam {
            role: "user".to_string(),
            content: vec![ContentBlockParam {
                inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
            }],
        }
    }
775
    /// Builds an Anthropic user message with a URL-sourced PDF document block.
    fn create_anthropic_url_pdf_message() -> MessageParam {
        let pdf_source = UrlPDFSource::new("https://example.com/document.pdf".to_string());
        let document_block = DocumentBlockParam {
            source: crate::anthropic::v1::request::DocumentSource::Url(pdf_source),
            cache_control: None,
            title: Some("test_document.pdf".to_string()),
            context: None,
            r#type: "document".to_string(),
            citations: None,
        };
        MessageParam {
            role: "user".to_string(),
            content: vec![ContentBlockParam {
                inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
            }],
        }
    }
793
    /// Builds an Anthropic user message with a plain-text document block
    /// (including title and context fields).
    fn create_anthropic_plain_text_document_message() -> MessageParam {
        let text_source = PlainTextSource::new("Plain text document content".to_string());
        let document_block = DocumentBlockParam {
            source: crate::anthropic::v1::request::DocumentSource::Text(text_source),
            cache_control: None,
            title: Some("text_document.txt".to_string()),
            context: Some("Context for the document".to_string()),
            r#type: "document".to_string(),
            citations: None,
        };
        MessageParam {
            role: "user".to_string(),
            content: vec![ContentBlockParam {
                inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
            }],
        }
    }
811
    #[test]
    // Verifies that Prompt::new_rs extracts `${...}` template parameters from
    // a message and that `bind` substitutes their values.
    fn test_task_list_add_and_get() {
        let text_part = TextContentPart::new("Test prompt. ${param1} ${param2}".to_string());
        let content_part = ContentPart::Text(text_part);
        let message = OpenAIChatMessage {
            role: "user".to_string(),
            content: vec![content_part],
            name: None,
        };

        let prompt = Prompt::new_rs(
            vec![MessageNum::OpenAIMessageV1(message)],
            "gpt-4o",
            Provider::OpenAI,
            vec![],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap();

        // Check if the prompt was created successfully
        assert_eq!(prompt.request.messages().len(), 1);

        // check prompt parameters
        assert!(prompt.parameters.len() == 2);

        // sort parameters to ensure order does not affect the test
        let mut parameters = prompt.parameters.clone();
        parameters.sort();

        assert_eq!(parameters[0], "param1");
        assert_eq!(parameters[1], "param2");

        // bind parameter
        let bound_msg = prompt.request.messages()[0]
            .bind("param1", "Value1")
            .unwrap();
        let bound_msg = bound_msg.bind("param2", "Value2").unwrap();

        // Check if the bound message contains the correct values
        match bound_msg.clone() {
            MessageNum::OpenAIMessageV1(msg) => {
                if let ContentPart::Text(text_part) = &msg.content[0] {
                    assert_eq!(text_part.text, "Test prompt. Value1 Value2");
                } else {
                    panic!("Expected TextContentPart");
                }
            }
            _ => panic!("Expected OpenAIMessageV1"),
        }
    }
864
865    #[test]
866    fn test_image_prompt() {
867        let text_message = create_openai_chat_message();
868        let image_message = create_openai_image_message();
869
870        let system_text_part = TextContentPart::new("system_prompt".to_string());
871        let system_text_content_part = ContentPart::Text(system_text_part);
872
873        let system_text_message = OpenAIChatMessage {
874            role: "assistant".to_string(),
875            content: vec![system_text_content_part],
876            name: None,
877        };
878
879        let prompt = Prompt::new_rs(
880            vec![
881                MessageNum::OpenAIMessageV1(text_message),
882                MessageNum::OpenAIMessageV1(image_message),
883            ],
884            "gpt-4o",
885            Provider::OpenAI,
886            vec![MessageNum::OpenAIMessageV1(system_text_message)],
887            None,
888            None,
889            ResponseType::Null,
890        )
891        .unwrap();
892
893        // Check the first user message
894        if let MessageNum::OpenAIMessageV1(msg) = &prompt.request.messages()[1] {
895            if let ContentPart::Text(text_part) = &msg.content[0] {
896                assert_eq!(text_part.text, "What company is this logo from?");
897            } else {
898                panic!("Expected TextContentPart for the first user message");
899            }
900        } else {
901            panic!("Expected OpenAIMessageV1 for the first user message");
902        }
903
904        // Check the second user message (ImageUrl)
905        if let MessageNum::OpenAIMessageV1(msg) = &prompt.request.messages()[2] {
906            if let ContentPart::ImageUrl(image_url) = &msg.content[0] {
907                assert_eq!(image_url.image_url.url, "https://iili.io/3Hs4FMg.png");
908                assert_eq!(image_url.r#type, "image_url");
909            } else {
910                panic!("Expected ContentPart::Image for the second user message");
911            }
912        } else {
913            panic!("Expected OpenAIMessageV1 for the second user message");
914        }
915    }
916
917    #[test]
918    fn test_document_prompt() {
919        let text_message = create_openai_chat_message();
920        let file_message = create_openai_file_message();
921        let system_message = create_system_openai_chat_message();
922
923        let prompt = Prompt::new_rs(
924            vec![
925                MessageNum::OpenAIMessageV1(text_message),
926                MessageNum::OpenAIMessageV1(file_message),
927            ],
928            "gpt-4o",
929            Provider::OpenAI,
930            vec![MessageNum::OpenAIMessageV1(system_message)],
931            None,
932            None,
933            ResponseType::Null,
934        )
935        .unwrap();
936
937        // Check the 2nd user message (file)
938        if let MessageNum::OpenAIMessageV1(msg) = &prompt.request.messages()[2] {
939            if let ContentPart::FileContent(file_content) = &msg.content[0] {
940                assert_eq!(file_content.file.file_id.as_ref().unwrap(), "fileid");
941                assert_eq!(file_content.file.filename.as_ref().unwrap(), "filename");
942            } else {
943                panic!("Expected ContentPart::FileContent for the second user message");
944            }
945        } else {
946            panic!("Expected OpenAIMessageV1 for the first user message");
947        }
948    }
949
950    #[test]
951    fn test_response_format_score() {
952        let text_message = create_openai_chat_message();
953        let prompt = Prompt::new_rs(
954            vec![MessageNum::OpenAIMessageV1(text_message)],
955            "gpt-4o",
956            Provider::OpenAI,
957            vec![],
958            None,
959            Some(Score::get_structured_output_schema()),
960            ResponseType::Null,
961        )
962        .unwrap();
963
964        // Check if the response json schema is set correctly
965        assert!(prompt.response_json_schema().is_some());
966    }
967
968    #[test]
969    fn test_anthropic_text_message_binding() {
970        let text_block =
971            TextBlockParam::new_rs("Test prompt. ${param1} ${param2}".to_string(), None, None);
972        let message = MessageParam {
973            role: "user".to_string(),
974            content: vec![ContentBlockParam {
975                inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
976            }],
977        };
978
979        let prompt = Prompt::new_rs(
980            vec![MessageNum::AnthropicMessageV1(message)],
981            "claude-3-5-sonnet-20241022",
982            Provider::Anthropic,
983            vec![],
984            None,
985            None,
986            ResponseType::Null,
987        )
988        .unwrap();
989
990        assert_eq!(prompt.request.messages().len(), 1);
991        assert_eq!(prompt.parameters.len(), 2);
992
993        let mut parameters = prompt.parameters.clone();
994        parameters.sort();
995        assert_eq!(parameters[0], "param1");
996        assert_eq!(parameters[1], "param2");
997
998        // Test parameter binding
999        let bound_msg = prompt.request.messages()[0]
1000            .bind("param1", "Value1")
1001            .unwrap();
1002        let bound_msg = bound_msg.bind("param2", "Value2").unwrap();
1003
1004        match bound_msg {
1005            MessageNum::AnthropicMessageV1(msg) => {
1006                if let crate::anthropic::v1::request::ContentBlock::Text(text_block) =
1007                    &msg.content[0].inner
1008                {
1009                    assert_eq!(text_block.text, "Test prompt. Value1 Value2");
1010                } else {
1011                    panic!("Expected TextBlockParam");
1012                }
1013            }
1014            _ => panic!("Expected AnthropicMessageV1"),
1015        }
1016    }
1017
1018    #[test]
1019    fn test_anthropic_url_image_prompt() {
1020        let text_message = create_anthropic_text_message();
1021        let image_message = create_anthropic_url_image_message();
1022        let system_message = create_anthropic_system_message();
1023
1024        let prompt = Prompt::new_rs(
1025            vec![
1026                MessageNum::AnthropicMessageV1(text_message),
1027                MessageNum::AnthropicMessageV1(image_message),
1028            ],
1029            "claude-3-5-sonnet-20241022",
1030            Provider::Anthropic,
1031            vec![MessageNum::AnthropicMessageV1(system_message)],
1032            None,
1033            None,
1034            ResponseType::Null,
1035        )
1036        .unwrap();
1037
1038        // Check first message (text)
1039        if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[0] {
1040            if let crate::anthropic::v1::request::ContentBlock::Text(text_block) =
1041                &msg.content[0].inner
1042            {
1043                assert_eq!(text_block.text, "What company is this logo from?");
1044            } else {
1045                panic!("Expected TextBlock for first message");
1046            }
1047        } else {
1048            panic!("Expected AnthropicMessageV1");
1049        }
1050
1051        // Check second message (image URL)
1052        if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1053            if let crate::anthropic::v1::request::ContentBlock::Image(image_block) =
1054                &msg.content[0].inner
1055            {
1056                match &image_block.source {
1057                    crate::anthropic::v1::request::ImageSource::Url(url_source) => {
1058                        assert_eq!(url_source.url, "https://iili.io/3Hs4FMg.png");
1059                        assert_eq!(url_source.r#type, "url");
1060                    }
1061                    _ => panic!("Expected URL image source"),
1062                }
1063                assert_eq!(image_block.r#type, "image");
1064            } else {
1065                panic!("Expected ImageBlock for second message");
1066            }
1067        } else {
1068            panic!("Expected AnthropicMessageV1");
1069        }
1070    }
1071
1072    #[test]
1073    fn test_anthropic_base64_image_prompt() {
1074        let text_message = create_anthropic_text_message();
1075        let image_message = create_anthropic_base64_image_message();
1076
1077        let prompt = Prompt::new_rs(
1078            vec![
1079                MessageNum::AnthropicMessageV1(text_message),
1080                MessageNum::AnthropicMessageV1(image_message),
1081            ],
1082            "claude-3-5-sonnet-20241022",
1083            Provider::Anthropic,
1084            vec![],
1085            None,
1086            None,
1087            ResponseType::Null,
1088        )
1089        .unwrap();
1090
1091        // Check second message (base64 image)
1092        if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1093            if let crate::anthropic::v1::request::ContentBlock::Image(image_block) =
1094                &msg.content[0].inner
1095            {
1096                match &image_block.source {
1097                    crate::anthropic::v1::request::ImageSource::Base64(base64_source) => {
1098                        assert_eq!(base64_source.media_type, "image/png");
1099                        assert_eq!(base64_source.data, "base64data");
1100                        assert_eq!(base64_source.r#type, "base64");
1101                    }
1102                    _ => panic!("Expected Base64 image source"),
1103                }
1104            } else {
1105                panic!("Expected ImageBlock");
1106            }
1107        } else {
1108            panic!("Expected AnthropicMessageV1");
1109        }
1110    }
1111
1112    // Test: Anthropic PDF document (base64)
1113    #[test]
1114    fn test_anthropic_base64_pdf_document_prompt() {
1115        let text_message = create_anthropic_text_message();
1116        let pdf_message = create_anthropic_base64_pdf_message();
1117        let system_message = create_anthropic_system_message();
1118
1119        let prompt = Prompt::new_rs(
1120            vec![
1121                MessageNum::AnthropicMessageV1(text_message),
1122                MessageNum::AnthropicMessageV1(pdf_message),
1123            ],
1124            "claude-3-5-sonnet-20241022",
1125            Provider::Anthropic,
1126            vec![MessageNum::AnthropicMessageV1(system_message)],
1127            None,
1128            None,
1129            ResponseType::Null,
1130        )
1131        .unwrap();
1132
1133        // Check second message (PDF document)
1134        if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1135            if let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
1136                &msg.content[0].inner
1137            {
1138                match &document_block.source {
1139                    crate::anthropic::v1::request::DocumentSource::Base64(pdf_source) => {
1140                        assert_eq!(pdf_source.media_type, "application/pdf");
1141                        assert_eq!(pdf_source.data, "base64pdfdata");
1142                        assert_eq!(pdf_source.r#type, "base64");
1143                    }
1144                    _ => panic!("Expected Base64 PDF source"),
1145                }
1146                assert_eq!(document_block.r#type, "document");
1147                assert_eq!(document_block.title.as_ref().unwrap(), "test_document.pdf");
1148            } else {
1149                panic!("Expected DocumentBlock");
1150            }
1151        } else {
1152            panic!("Expected AnthropicMessageV1");
1153        }
1154    }
1155
1156    // Test: Anthropic URL PDF document
1157    #[test]
1158    fn test_anthropic_url_pdf_document_prompt() {
1159        let text_message = create_anthropic_text_message();
1160        let pdf_message = create_anthropic_url_pdf_message();
1161
1162        let prompt = Prompt::new_rs(
1163            vec![
1164                MessageNum::AnthropicMessageV1(text_message),
1165                MessageNum::AnthropicMessageV1(pdf_message),
1166            ],
1167            "claude-3-5-sonnet-20241022",
1168            Provider::Anthropic,
1169            vec![],
1170            None,
1171            None,
1172            ResponseType::Null,
1173        )
1174        .unwrap();
1175
1176        // Check second message (URL PDF)
1177        if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1178            if let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
1179                &msg.content[0].inner
1180            {
1181                match &document_block.source {
1182                    crate::anthropic::v1::request::DocumentSource::Url(url_source) => {
1183                        assert_eq!(url_source.url, "https://example.com/document.pdf");
1184                        assert_eq!(url_source.r#type, "url");
1185                    }
1186                    _ => panic!("Expected URL PDF source"),
1187                }
1188            } else {
1189                panic!("Expected DocumentBlock");
1190            }
1191        } else {
1192            panic!("Expected AnthropicMessageV1");
1193        }
1194    }
1195
1196    // Test: Anthropic plain text document
1197    #[test]
1198    fn test_anthropic_plain_text_document_prompt() {
1199        let text_message = create_anthropic_text_message();
1200        let text_doc_message = create_anthropic_plain_text_document_message();
1201
1202        let prompt = Prompt::new_rs(
1203            vec![
1204                MessageNum::AnthropicMessageV1(text_message),
1205                MessageNum::AnthropicMessageV1(text_doc_message),
1206            ],
1207            "claude-3-5-sonnet-20241022",
1208            Provider::Anthropic,
1209            vec![],
1210            None,
1211            None,
1212            ResponseType::Null,
1213        )
1214        .unwrap();
1215
1216        // Check second message (plain text document)
1217        if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1218            if let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
1219                &msg.content[0].inner
1220            {
1221                match &document_block.source {
1222                    crate::anthropic::v1::request::DocumentSource::Text(text_source) => {
1223                        assert_eq!(text_source.media_type, "text/plain");
1224                        assert_eq!(text_source.data, "Plain text document content");
1225                        assert_eq!(text_source.r#type, "text");
1226                    }
1227                    _ => panic!("Expected Text document source"),
1228                }
1229                assert_eq!(
1230                    document_block.context.as_ref().unwrap(),
1231                    "Context for the document"
1232                );
1233            } else {
1234                panic!("Expected DocumentBlock");
1235            }
1236        } else {
1237            panic!("Expected AnthropicMessageV1");
1238        }
1239    }
1240
1241    // Test: Mixed Anthropic content (text + multiple documents)
1242    #[test]
1243    fn test_anthropic_mixed_content_prompt() {
1244        let text_message = create_anthropic_text_message();
1245        let pdf_message = create_anthropic_base64_pdf_message();
1246        let text_doc_message = create_anthropic_plain_text_document_message();
1247        let system_message = create_anthropic_system_message();
1248
1249        let prompt = Prompt::new_rs(
1250            vec![
1251                MessageNum::AnthropicMessageV1(text_message),
1252                MessageNum::AnthropicMessageV1(pdf_message),
1253                MessageNum::AnthropicMessageV1(text_doc_message),
1254            ],
1255            "claude-3-5-sonnet-20241022",
1256            Provider::Anthropic,
1257            vec![MessageNum::AnthropicMessageV1(system_message)],
1258            None,
1259            None,
1260            ResponseType::Null,
1261        )
1262        .unwrap();
1263
1264        assert_eq!(prompt.request.messages().len(), 3);
1265        assert_eq!(prompt.request.system_instructions().len(), 1);
1266        assert_eq!(prompt.provider, Provider::Anthropic);
1267        assert_eq!(prompt.model, "claude-3-5-sonnet-20241022");
1268    }
1269
1270    // gemini test
1271    #[test]
1272    fn test_gemini_chat_message() {
1273        let text = Part::from_text("Test prompt. ${param1} ${param2}".to_string());
1274        let message = GeminiContent {
1275            role: "user".to_string(),
1276            parts: vec![text],
1277        };
1278
1279        let prompt = Prompt::new_rs(
1280            vec![MessageNum::GeminiContentV1(message)],
1281            "gemini-1.5-pro",
1282            Provider::Google,
1283            vec![],
1284            None,
1285            None,
1286            ResponseType::Null,
1287        )
1288        .unwrap();
1289
1290        assert_eq!(prompt.request.messages().len(), 1);
1291        assert_eq!(prompt.parameters.len(), 2);
1292
1293        let mut parameters = prompt.parameters.clone();
1294        parameters.sort();
1295        assert_eq!(parameters[0], "param1");
1296        assert_eq!(parameters[1], "param2");
1297
1298        // Test parameter binding
1299        let bound_msg = prompt.request.messages()[0]
1300            .bind("param1", "Value1")
1301            .unwrap();
1302        let bound_msg = bound_msg.bind("param2", "Value2").unwrap();
1303
1304        match bound_msg {
1305            MessageNum::GeminiContentV1(msg) => {
1306                if let DataNum::Text(text_part) = &msg.parts[0].data {
1307                    assert_eq!(text_part, "Test prompt. Value1 Value2");
1308                } else {
1309                    panic!("Expected Text Part");
1310                }
1311            }
1312            _ => panic!("Expected GeminiContentV1"),
1313        }
1314    }
1315}