1use crate::anthropic::v1::request::{AnthropicSettings, MessageParam as AnthropicMessage};
2use crate::error::TypeError;
3use crate::google::v1::generate::request::{GeminiContent, GeminiSettings};
4use crate::openai::v1::chat::request::ChatMessage as OpenAIChatMessage;
5use crate::openai::v1::chat::settings::OpenAIChatSettings;
6use crate::prompt::builder::{to_provider_request, ProviderRequest};
7use crate::prompt::settings::ModelSettings;
8use crate::prompt::types::parse_response_to_json;
9use crate::prompt::types::ResponseType;
10use crate::prompt::types::Role;
11use crate::prompt::{AnthropicMessageList, GeminiContentList, MessageNum, OpenAIMessageList};
12use crate::tools::AgentToolDefinition;
13use crate::traits::MessageFactory;
14use crate::SettingsType;
15use crate::{Provider, SaveName};
16use potato_util::utils::extract_string_value;
17use potato_util::PyHelperFuncs;
18use potatohead_macro::try_extract_message;
19use pyo3::prelude::*;
20use pyo3::types::{PyDict, PyList, PyString, PyTuple};
21use pythonize::pythonize;
22use serde::{Deserialize, Deserializer, Serialize};
23use serde_json::Value;
24use std::collections::BTreeSet;
25use std::path::PathBuf;
26
27fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
30where
31 D: Deserializer<'de>,
32{
33 #[derive(Deserialize)]
34 #[serde(untagged)]
35 enum StringOrVec {
36 Single(String),
37 List(Vec<String>),
38 }
39
40 match StringOrVec::deserialize(deserializer)? {
41 StringOrVec::Single(s) => Ok(vec![s]),
42 StringOrVec::List(v) => Ok(v),
43 }
44}
45
/// Simplified, provider-agnostic prompt description that can be deserialized
/// and then converted into a [`Prompt`] with `Prompt::from_generic_config`.
#[derive(Debug, Deserialize)]
pub struct GenericPromptConfig {
    /// Model name forwarded to the prompt.
    model: String,
    /// Provider name; parsed with `Provider::from_string` during conversion.
    provider: String,
    /// User message texts; may be written as a single string or a list.
    #[serde(deserialize_with = "deserialize_string_or_vec")]
    messages: Vec<String>,
    /// Optional system instruction texts (list form only).
    #[serde(default)]
    system_instructions: Option<Vec<String>>,
    /// Optional raw settings, interpreted per provider during conversion.
    #[serde(default)]
    settings: Option<Value>,
    /// Optional raw value forwarded as the response JSON schema.
    response_format: Option<Value>,
}
72
73fn create_message_for_provider(
74 content: String,
75 provider: &Provider,
76 role: &str,
77) -> Result<MessageNum, TypeError> {
78 match provider {
79 Provider::OpenAI => {
80 OpenAIChatMessage::from_text(content, role).map(MessageNum::OpenAIMessageV1)
81 }
82 Provider::Anthropic => {
83 AnthropicMessage::from_text(content, role).map(MessageNum::AnthropicMessageV1)
84 }
85 Provider::Gemini | Provider::Google | Provider::Vertex => {
86 GeminiContent::from_text(content, role).map(MessageNum::GeminiContentV1)
87 }
88 _ => Err(TypeError::Error(format!(
89 "Unsupported provider for message creation: {:?}",
90 provider
91 ))),
92 }
93}
94
95fn parse_single_message(
96 message: &Bound<'_, PyAny>,
97 provider: &Provider,
98 default_role: &str,
99) -> Result<MessageNum, TypeError> {
100 if message.is_instance_of::<PyString>() {
102 let text = message.extract::<String>()?;
103 return create_message_for_provider(text, provider, default_role);
104 }
105
106 try_extract_message!(
108 message,
109 OpenAIChatMessage => MessageNum::OpenAIMessageV1,
110 AnthropicMessage => MessageNum::AnthropicMessageV1,
111 GeminiContent => MessageNum::GeminiContentV1,
112 );
113
114 Err(TypeError::InvalidMessageTypeInList(
115 message.get_type().name()?.to_string(),
116 ))
117}
118
119fn parse_messages(
120 messages: &Bound<'_, PyAny>,
121 provider: &Provider,
122 default_role: &str,
123) -> Result<Vec<MessageNum>, TypeError> {
124 let mut messages =
126 if !messages.is_instance_of::<PyList>() && !messages.is_instance_of::<PyTuple>() {
127 vec![parse_single_message(messages, provider, default_role)?]
128 } else {
129 messages
131 .try_iter()?
132 .map(|item| {
133 let item = item?;
134 parse_single_message(&item, provider, default_role)
135 })
136 .collect::<Result<Vec<_>, _>>()?
137 };
138
139 if provider == &Provider::Anthropic
142 && (default_role == Role::System.as_str()
143 || default_role == Role::Assistant.as_str()
144 || default_role == Role::Developer.as_str())
145 {
146 for msg in messages.iter_mut() {
147 msg.anthropic_message_to_system_message()?;
148 }
149 }
150
151 Ok(messages)
152}
153
154fn get_system_role(provider: &Provider) -> &'static str {
155 match provider {
156 Provider::OpenAI => Role::Developer.into(),
157 Provider::Gemini | Provider::Vertex | Provider::Google => Role::Model.into(),
158 Provider::Anthropic => Role::System.into(),
159 _ => Role::System.into(),
160 }
161}
162
163pub fn extract_system_instructions(
165 system_instruction: Option<&Bound<'_, PyAny>>,
166 provider: &Provider,
167) -> Result<Option<Vec<MessageNum>>, TypeError> {
168 let system_instructions = if let Some(sys_inst) = system_instruction {
169 Some(parse_messages(
170 sys_inst,
171 provider,
172 get_system_role(provider),
173 )?)
174 } else {
175 None
176 };
177
178 Ok(system_instructions)
179}
180
/// A prompt bound to a specific model and provider, exposed to Python.
#[pyclass]
#[derive(Debug, Serialize, Clone, PartialEq)]
pub struct Prompt {
    /// Provider-specific request holding the messages, system instructions,
    /// settings and optional response schema.
    pub request: ProviderRequest,

    /// Model name.
    #[pyo3(get)]
    pub model: String,

    /// Provider the request was built for.
    #[pyo3(get)]
    pub provider: Provider,

    /// Version string; `new_rs` sets it from `potato_util::version()`.
    pub version: String,

    /// Template variable names found in the messages and system instructions.
    #[pyo3(get)]
    #[serde(default)]
    pub parameters: Vec<String>,

    /// Kind of structured response expected from the model.
    #[serde(default)]
    pub response_type: ResponseType,
}
201
202fn extract_model_settings(model_settings: &Bound<'_, PyAny>) -> Result<ModelSettings, TypeError> {
204 let settings_type = model_settings
205 .call_method0("settings_type")?
206 .extract::<SettingsType>()?;
207
208 match settings_type {
209 SettingsType::OpenAIChat => model_settings
210 .extract::<OpenAIChatSettings>()
211 .map(ModelSettings::OpenAIChat),
212 SettingsType::GoogleChat => model_settings
213 .extract::<GeminiSettings>()
214 .map(ModelSettings::GoogleChat),
215 SettingsType::Anthropic => model_settings
216 .extract::<AnthropicSettings>()
217 .map(ModelSettings::AnthropicChat),
218 SettingsType::ModelSettings => model_settings.extract::<ModelSettings>(),
219 }
220 .map_err(Into::into)
221}
222
/// The two on-disk shapes a prompt can be deserialized from.
///
/// The enum is untagged, so serde tries `Generic` first and falls back to
/// `Full`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum PromptFormat {
    /// Simplified config that is converted via `Prompt::from_generic_config`.
    Generic(GenericPromptConfig),
    /// Fully serialized prompt; boxed to keep the enum small.
    Full(Box<PromptInternal>),
}
229
/// Field-for-field mirror of [`Prompt`] used only for deserializing the full
/// format, since `Prompt` implements `Deserialize` by hand.
#[derive(Debug, Deserialize)]
struct PromptInternal {
    /// Provider-specific request payload.
    request: ProviderRequest,
    /// Model name.
    model: String,
    /// Provider the request targets.
    provider: Provider,
    /// Version string stored with the prompt.
    version: String,
    /// Template variable names; empty when absent from the input.
    #[serde(default)]
    parameters: Vec<String>,
    /// Expected response kind; falls back to the type's default when absent.
    #[serde(default)]
    response_type: ResponseType,
}
241
242impl<'de> Deserialize<'de> for Prompt {
243 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
244 where
245 D: serde::Deserializer<'de>,
246 {
247 let format = PromptFormat::deserialize(deserializer)?;
248
249 match format {
250 PromptFormat::Generic(config) => Self::from_generic_config(config)
251 .map_err(|e| serde::de::Error::custom(e.to_string())),
252 PromptFormat::Full(internal) => Ok(Prompt {
253 request: internal.request,
254 model: internal.model,
255 provider: internal.provider,
256 version: internal.version,
257 parameters: internal.parameters,
258 response_type: internal.response_type,
259 }),
260 }
261 }
262}
263
264#[pymethods]
265impl Prompt {
266 #[new]
282 #[pyo3(signature = (messages, model, provider, system_instructions=None, model_settings=None, output_type=None))]
283 pub fn new(
284 py: Python<'_>,
285 messages: &Bound<'_, PyAny>,
286 model: &str,
287 provider: &Bound<'_, PyAny>,
288 system_instructions: Option<&Bound<'_, PyAny>>,
289 model_settings: Option<&Bound<'_, PyAny>>,
290 output_type: Option<&Bound<'_, PyAny>>, ) -> Result<Self, TypeError> {
292 let model_settings = model_settings
294 .as_ref()
295 .map(|s| extract_model_settings(s))
296 .transpose()?;
297
298 let provider = Provider::extract_provider(provider)?;
300
301 let messages = parse_messages(messages, &provider, Role::User.into())?;
304 let system_instructions = if let Some(sys_inst) = system_instructions {
305 parse_messages(sys_inst, &provider, get_system_role(&provider))?
306 } else {
307 vec![]
308 };
309
310 let (response_type, response_json_schema) = match output_type {
312 Some(output_type) => {
313 parse_response_to_json(py, output_type)?
315 }
316 None => (ResponseType::Null, None),
317 };
318
319 Self::new_rs(
320 messages,
321 model,
322 provider,
323 system_instructions,
324 model_settings,
325 response_json_schema,
326 response_type,
327 )
328 }
329
330 #[getter]
331 pub fn model_settings<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
332 self.request.model_settings(py)
333 }
334
335 #[getter]
336 pub fn model_identifier(&self) -> String {
337 format!("{}:{}", self.provider.as_str(), self.model)
338 }
339
340 #[pyo3(signature = (path = None))]
341 pub fn save_prompt(&self, path: Option<PathBuf>) -> PyResult<PathBuf> {
342 let save_path = path.unwrap_or_else(|| PathBuf::from(SaveName::Prompt));
343 PyHelperFuncs::save_to_json(self, &save_path)?;
344 Ok(save_path)
345 }
346
347 #[staticmethod]
348 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
349 let content = std::fs::read_to_string(&path)?;
350
351 let extension = path
352 .extension()
353 .and_then(|ext| ext.to_str())
354 .ok_or_else(|| TypeError::Error(format!("Invalid file path: {:?}", path)))?;
355
356 let mut prompt: Prompt = match extension.to_lowercase().as_str() {
357 "json" => serde_json::from_str(&content)?,
358 "yaml" | "yml" => serde_yaml::from_str(&content)?,
359 _ => {
360 return Err(TypeError::Error(format!(
361 "Unsupported file extension '{}'. Expected .json, .yaml, or .yml",
362 extension
363 )))
364 }
365 };
366
367 if prompt.parameters.is_empty() {
368 let system_instructions: Vec<MessageNum> = prompt
369 .request
370 .system_instructions()
371 .iter()
372 .map(|msg| (*msg).clone())
373 .collect();
374 let parameters =
375 Self::extract_variables(prompt.request.messages(), &system_instructions);
376 prompt.parameters = parameters;
377 }
378
379 Ok(prompt)
380 }
381
382 #[staticmethod]
383 pub fn model_validate_json(json_string: String) -> Result<Self, TypeError> {
384 let json_value: Value = serde_json::from_str(&json_string)?;
385 let model: Self = serde_json::from_value(json_value)?;
386
387 Ok(model)
388 }
389
390 pub fn model_dump_json(&self) -> String {
391 serde_json::to_string(self).unwrap()
392 }
393
394 pub fn __str__(&self) -> String {
395 PyHelperFuncs::__str__(self)
396 }
397
398 #[getter]
399 pub fn all_messages<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyList>, TypeError> {
401 self.request.get_all_py_messages(py)
402 }
403
404 #[getter]
405 pub fn messages<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyList>, TypeError> {
407 self.request.get_py_messages(py)
408 }
409
410 #[getter]
411 pub fn message<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
413 self.request.get_py_message(py)
414 }
415
416 #[getter]
417 pub fn openai_messages(&self) -> Result<OpenAIMessageList, TypeError> {
420 if self.provider != Provider::OpenAI {
421 return Err(TypeError::Error(
422 "Prompt provider is not OpenAI".to_string(),
423 ));
424 }
425 let messages = self
426 .request
427 .messages()
428 .iter()
429 .filter(|msg| msg.is_user_message())
430 .filter_map(|msg| match msg {
431 MessageNum::OpenAIMessageV1(m) => Some(m.clone()),
432 _ => None,
433 })
434 .collect::<Vec<_>>();
435 Ok(OpenAIMessageList { messages })
436 }
437
438 #[getter]
439 pub fn openai_message(&self) -> Result<OpenAIChatMessage, TypeError> {
442 if self.provider != Provider::OpenAI {
443 return Err(TypeError::Error(
444 "Prompt provider is not OpenAI".to_string(),
445 ));
446 }
447 self.request.get_openai_message()
448 }
449
450 #[getter]
451 pub fn gemini_messages(&self) -> Result<GeminiContentList, TypeError> {
454 if !self.is_google_provider() {
455 return Err(TypeError::Error(
456 "Prompt provider is not Google, Gemini, or Vertex".to_string(),
457 ));
458 }
459 let messages = self
460 .request
461 .messages()
462 .iter()
463 .filter(|msg| msg.is_user_message())
464 .filter_map(|msg| match msg {
465 MessageNum::GeminiContentV1(m) => Some(m.clone()),
466 _ => None,
467 })
468 .collect::<Vec<_>>();
469
470 Ok(GeminiContentList { messages })
471 }
472
473 #[getter]
474 pub fn gemini_message(&self) -> Result<GeminiContent, TypeError> {
477 if !self.is_google_provider() {
478 return Err(TypeError::Error(
479 "Prompt provider is not Google, Gemini, or Vertex".to_string(),
480 ));
481 }
482 self.request.get_gemini_message()
483 }
484
485 #[getter]
486 pub fn anthropic_messages(&self) -> Result<AnthropicMessageList, TypeError> {
487 if self.provider != Provider::Anthropic {
488 return Err(TypeError::Error(
489 "Prompt provider is not Anthropic".to_string(),
490 ));
491 }
492 let messages = self
493 .request
494 .messages()
495 .iter()
496 .filter(|msg| msg.is_user_message())
497 .filter_map(|msg| match msg {
498 MessageNum::AnthropicMessageV1(m) => Some(m.clone()),
499 _ => None,
500 })
501 .collect::<Vec<_>>();
502
503 Ok(AnthropicMessageList { messages })
504 }
505
506 #[getter]
507 pub fn anthropic_message(&self) -> Result<AnthropicMessage, TypeError> {
509 if self.provider != Provider::Anthropic {
510 return Err(TypeError::Error(
511 "Prompt provider is not Anthropic".to_string(),
512 ));
513 }
514 self.request.get_anthropic_message()
515 }
516
517 #[getter]
518 pub fn system_instructions<'py>(
519 &self,
520 py: Python<'py>,
521 ) -> Result<Bound<'py, PyList>, TypeError> {
522 self.request.get_py_system_instructions(py)
523 }
524
525 #[pyo3(signature = (name=None, value=None, **kwargs))]
533 pub fn bind(
534 &self,
535 name: Option<&str>,
536 value: Option<&Bound<'_, PyAny>>,
537 kwargs: Option<&Bound<'_, PyDict>>,
538 ) -> Result<Self, TypeError> {
539 let mut new_prompt = self.clone();
540
541 if let (Some(name), Some(value)) = (name, value) {
542 let var_value = extract_string_value(value)?;
543 for message in new_prompt.request.messages_mut() {
544 message.bind_mut(name, &var_value)?;
545 }
546 }
547
548 if let Some(kwargs) = kwargs {
549 for (key, val) in kwargs.iter() {
550 let var_name = key.extract::<String>()?;
551 let var_value = extract_string_value(&val)?;
552
553 for message in new_prompt.request.messages_mut() {
554 message.bind_mut(&var_name, &var_value)?;
555 }
556 }
557 }
558
559 if name.is_none() && kwargs.is_none_or(|k| k.is_empty()) {
560 return Err(TypeError::Error(
561 "Must provide either (name, value) or keyword arguments for binding".to_string(),
562 ));
563 }
564
565 Ok(new_prompt)
566 }
567
568 #[pyo3(signature = (name=None, value=None, **kwargs))]
575 pub fn bind_mut(
576 &mut self,
577 name: Option<&str>,
578 value: Option<&Bound<'_, PyAny>>,
579 kwargs: Option<&Bound<'_, PyDict>>,
580 ) -> Result<(), TypeError> {
581 if let (Some(name), Some(value)) = (name, value) {
582 let var_value = extract_string_value(value)?;
583 for message in self.request.messages_mut() {
584 message.bind_mut(name, &var_value)?;
585 }
586 }
587
588 if let Some(kwargs) = kwargs {
589 for (key, val) in kwargs.iter() {
590 let var_name = key.extract::<String>()?;
591 let var_value = extract_string_value(&val)?;
592
593 for message in self.request.messages_mut() {
594 message.bind_mut(&var_name, &var_value)?;
595 }
596 }
597 }
598
599 if name.is_none() && kwargs.is_none_or(|k| k.is_empty()) {
600 return Err(TypeError::Error(
601 "Must provide either (name, value) or keyword arguments for binding".to_string(),
602 ));
603 }
604
605 Ok(())
606 }
607
608 #[getter]
609 pub fn response_json_schema_pretty(&self) -> Option<String> {
610 Some(PyHelperFuncs::__str__(
611 self.request.response_json_schema().as_ref()?,
612 ))
613 }
614
615 #[getter]
616 #[pyo3(name = "response_json_schema")]
617 pub fn response_json_schema_py(&self) -> Option<String> {
618 Some(self.request.response_json_schema().as_ref()?.to_string())
619 }
620
621 pub fn model_dump<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
622 let request = &self.request.to_json()?;
623 Ok(pythonize(py, request)?)
624 }
625}
626
627impl Prompt {
628 pub fn from_generic_config(config: GenericPromptConfig) -> Result<Self, TypeError> {
631 if config.messages.is_empty() {
633 return Err(TypeError::Error(
634 "Prompt has no messages. Generic prompt format requires at least one message."
635 .to_string(),
636 ));
637 }
638
639 let provider = Provider::from_string(&config.provider)?;
641
642 let messages: Vec<MessageNum> = config
644 .messages
645 .into_iter()
646 .map(|msg| create_message_for_provider(msg, &provider, Role::User.as_str()))
647 .collect::<Result<Vec<_>, _>>()?;
648
649 let system_instructions = if let Some(sys_inst) = config.system_instructions {
651 sys_inst
652 .into_iter()
653 .map(|msg| create_message_for_provider(msg, &provider, get_system_role(&provider)))
654 .collect::<Result<Vec<_>, _>>()?
655 } else {
656 Vec::new()
657 };
658
659 let model_settings = if let Some(settings) = config.settings {
661 Some(Self::settings_from_value(settings, &provider)?)
662 } else {
663 None
664 };
665
666 Self::new_rs(
668 messages,
669 &config.model,
670 provider,
671 system_instructions,
672 model_settings,
673 config.response_format,
674 ResponseType::Null,
675 )
676 }
677
678 fn settings_from_value(value: Value, provider: &Provider) -> Result<ModelSettings, TypeError> {
681 match provider {
682 Provider::OpenAI => {
683 let settings: OpenAIChatSettings = serde_json::from_value(value)?;
684 Ok(ModelSettings::OpenAIChat(settings))
685 }
686 Provider::Anthropic => {
687 let settings: AnthropicSettings = serde_json::from_value(value)?;
688 Ok(ModelSettings::AnthropicChat(settings))
689 }
690 Provider::Gemini | Provider::Google | Provider::Vertex => {
691 let settings: GeminiSettings = serde_json::from_value(value)?;
692 Ok(ModelSettings::GoogleChat(settings))
693 }
694 _ => Err(TypeError::Error(format!(
695 "Settings not supported for provider: {:?}",
696 provider
697 ))),
698 }
699 }
700
701 pub fn response_json_schema(&self) -> Option<&Value> {
702 self.request.response_json_schema()
703 }
704
705 pub fn new_rs(
706 messages: Vec<MessageNum>,
707 model: &str,
708 provider: Provider,
709 system_instructions: Vec<MessageNum>,
710 model_settings: Option<ModelSettings>,
711 response_json_schema: Option<Value>,
712 response_type: ResponseType,
713 ) -> Result<Self, TypeError> {
714 let model = model.to_string();
715 let version = potato_util::version();
717 let model_settings = match model_settings {
719 Some(settings) => {
720 settings.validate_provider(&provider)?;
722 settings
723 }
724 None => ModelSettings::provider_default_settings(&provider),
725 };
726
727 let parameters = Self::extract_variables(&messages, &system_instructions);
729
730 let request = to_provider_request(
732 messages,
733 system_instructions,
734 model.clone(),
735 model_settings,
736 response_json_schema,
737 )?;
738
739 Ok(Self {
740 request,
741 version,
742 parameters,
743 response_type,
744 model,
745 provider,
746 })
747 }
748
749 fn is_google_provider(&self) -> bool {
750 matches!(
751 self.provider,
752 Provider::Google | Provider::Gemini | Provider::Vertex
753 )
754 }
755
756 pub fn add_tools(&mut self, tools: Vec<AgentToolDefinition>) -> Result<(), TypeError> {
757 self.request.add_tools(tools)
758 }
759
760 pub fn extract_variables(
761 messages: &[MessageNum],
762 system_instructions: &[MessageNum],
763 ) -> Vec<String> {
764 let mut variables = BTreeSet::new();
765
766 for msg in system_instructions {
768 variables.extend(msg.extract_variables());
769 }
770
771 for msg in messages {
773 variables.extend(msg.extract_variables());
774 }
775
776 variables.into_iter().collect()
777 }
778
779 pub fn model_dump_value(&self) -> Value {
780 serde_json::to_value(self).unwrap_or(Value::Null)
782 }
783
784 pub fn to_request_json(&self) -> Result<Value, TypeError> {
785 let json_value = serde_json::to_value(self)?;
787
788 Ok(json_value)
789 }
790
791 pub fn set_response_json_schema(
792 &mut self,
793 response_json_schema: Option<Value>,
794 response_type: ResponseType,
795 ) {
796 self.request.set_response_json_schema(response_json_schema);
797 self.response_type = response_type;
798 }
799}
800
801#[cfg(test)]
803mod tests {
804 use super::*;
805 use crate::anthropic::v1::request::{
806 Base64ImageSource, Base64PDFSource, ContentBlockParam, DocumentBlockParam, ImageBlockParam,
807 MessageParam, PlainTextSource, TextBlockParam, UrlImageSource, UrlPDFSource,
808 };
809 use crate::google::{DataNum, GeminiContent, Part};
810 use crate::openai::v1::chat::request::{
811 ChatMessage as OpenAIChatMessage, ContentPart, FileContentPart, ImageContentPart,
812 TextContentPart,
813 };
814 use crate::prompt::types::Score;
815 use crate::StructuredOutput;
816
817 fn create_openai_chat_message() -> OpenAIChatMessage {
818 let text_part = TextContentPart::new("What company is this logo from?".to_string());
819 let text_content_part = ContentPart::Text(text_part);
820 OpenAIChatMessage {
821 role: "user".to_string(),
822 content: vec![text_content_part],
823 name: None,
824 }
825 }
826
827 fn create_system_openai_chat_message() -> OpenAIChatMessage {
828 let text_part = TextContentPart::new("system_prompt".to_string());
829 let text_content_part = ContentPart::Text(text_part);
830 OpenAIChatMessage {
831 role: "developer".to_string(),
832 content: vec![text_content_part],
833 name: None,
834 }
835 }
836
837 fn create_openai_image_message() -> OpenAIChatMessage {
838 let image_part = ImageContentPart::new("https://iili.io/3Hs4FMg.png".to_string(), None);
839 let image_content_part = ContentPart::ImageUrl(image_part);
840 OpenAIChatMessage {
841 role: "user".to_string(),
842 content: vec![image_content_part],
843 name: None,
844 }
845 }
846
847 fn create_openai_file_message() -> OpenAIChatMessage {
848 let file_part = FileContentPart::new(
849 Some("filedata".to_string()),
850 Some("fileid".to_string()),
851 Some("filename".to_string()),
852 );
853 let file_content_part = ContentPart::FileContent(file_part);
854 OpenAIChatMessage {
855 role: "user".to_string(),
856 content: vec![file_content_part],
857 name: None,
858 }
859 }
860
861 fn create_anthropic_text_message() -> MessageParam {
862 let text_block =
863 TextBlockParam::new_rs("What company is this logo from?".to_string(), None, None);
864 MessageParam {
865 role: "user".to_string(),
866 content: vec![ContentBlockParam {
867 inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
868 }],
869 }
870 }
871
872 fn create_anthropic_system_message() -> MessageParam {
873 let text_block = TextBlockParam::new_rs("system_prompt".to_string(), None, None);
874 MessageParam {
875 role: "assistant".to_string(),
876 content: vec![ContentBlockParam {
877 inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
878 }],
879 }
880 }
881
882 fn create_anthropic_base64_image_message() -> MessageParam {
883 let image_source =
884 Base64ImageSource::new("image/png".to_string(), "base64data".to_string()).unwrap();
885 let image_block = ImageBlockParam {
886 source: crate::anthropic::v1::request::ImageSource::Base64(image_source),
887 cache_control: None,
888 r#type: "image".to_string(),
889 };
890 MessageParam {
891 role: "user".to_string(),
892 content: vec![ContentBlockParam {
893 inner: crate::anthropic::v1::request::ContentBlock::Image(image_block),
894 }],
895 }
896 }
897
898 fn create_anthropic_url_image_message() -> MessageParam {
899 let image_source = UrlImageSource::new("https://iili.io/3Hs4FMg.png".to_string());
900 let image_block = ImageBlockParam {
901 source: crate::anthropic::v1::request::ImageSource::Url(image_source),
902 cache_control: None,
903 r#type: "image".to_string(),
904 };
905 MessageParam {
906 role: "user".to_string(),
907 content: vec![ContentBlockParam {
908 inner: crate::anthropic::v1::request::ContentBlock::Image(image_block),
909 }],
910 }
911 }
912
913 fn create_anthropic_base64_pdf_message() -> MessageParam {
914 let pdf_source = Base64PDFSource::new("base64pdfdata".to_string()).unwrap();
915 let document_block = DocumentBlockParam {
916 source: crate::anthropic::v1::request::DocumentSource::Base64(pdf_source),
917 cache_control: None,
918 title: Some("test_document.pdf".to_string()),
919 context: None,
920 r#type: "document".to_string(),
921 citations: None,
922 };
923 MessageParam {
924 role: "user".to_string(),
925 content: vec![ContentBlockParam {
926 inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
927 }],
928 }
929 }
930
931 fn create_anthropic_url_pdf_message() -> MessageParam {
932 let pdf_source = UrlPDFSource::new("https://example.com/document.pdf".to_string());
933 let document_block = DocumentBlockParam {
934 source: crate::anthropic::v1::request::DocumentSource::Url(pdf_source),
935 cache_control: None,
936 title: Some("test_document.pdf".to_string()),
937 context: None,
938 r#type: "document".to_string(),
939 citations: None,
940 };
941 MessageParam {
942 role: "user".to_string(),
943 content: vec![ContentBlockParam {
944 inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
945 }],
946 }
947 }
948
949 fn create_anthropic_plain_text_document_message() -> MessageParam {
950 let text_source = PlainTextSource::new("Plain text document content".to_string());
951 let document_block = DocumentBlockParam {
952 source: crate::anthropic::v1::request::DocumentSource::Text(text_source),
953 cache_control: None,
954 title: Some("text_document.txt".to_string()),
955 context: Some("Context for the document".to_string()),
956 r#type: "document".to_string(),
957 citations: None,
958 };
959 MessageParam {
960 role: "user".to_string(),
961 content: vec![ContentBlockParam {
962 inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
963 }],
964 }
965 }
966
/// An OpenAI text prompt exposes its two template variables and binds both.
#[test]
fn test_task_list_add_and_get() {
    let message = OpenAIChatMessage {
        role: "user".to_string(),
        content: vec![ContentPart::Text(TextContentPart::new(
            "Test prompt. ${param1} ${param2}".to_string(),
        ))],
        name: None,
    };

    let prompt = Prompt::new_rs(
        vec![MessageNum::OpenAIMessageV1(message)],
        "gpt-4o",
        Provider::OpenAI,
        vec![],
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    assert_eq!(prompt.request.messages().len(), 1);
    assert_eq!(prompt.parameters.len(), 2);

    let mut parameters = prompt.parameters.clone();
    parameters.sort();
    assert_eq!(parameters, ["param1", "param2"]);

    // Bind both variables one after the other on the first message.
    let bound_msg = prompt.request.messages()[0]
        .bind("param1", "Value1")
        .and_then(|partially_bound| partially_bound.bind("param2", "Value2"))
        .unwrap();

    let MessageNum::OpenAIMessageV1(msg) = bound_msg else {
        panic!("Expected OpenAIMessageV1");
    };
    let ContentPart::Text(text_part) = &msg.content[0] else {
        panic!("Expected TextContentPart");
    };
    assert_eq!(text_part.text, "Test prompt. Value1 Value2");
}
1019
/// An OpenAI prompt with a system instruction keeps the text and image user
/// messages at indices 1 and 2 of the request's message list.
#[test]
fn test_image_prompt() {
    let system_text_message = OpenAIChatMessage {
        role: "assistant".to_string(),
        content: vec![ContentPart::Text(TextContentPart::new(
            "system_prompt".to_string(),
        ))],
        name: None,
    };

    let prompt = Prompt::new_rs(
        vec![
            MessageNum::OpenAIMessageV1(create_openai_chat_message()),
            MessageNum::OpenAIMessageV1(create_openai_image_message()),
        ],
        "gpt-4o",
        Provider::OpenAI,
        vec![MessageNum::OpenAIMessageV1(system_text_message)],
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    let messages = prompt.request.messages();

    let MessageNum::OpenAIMessageV1(text_msg) = &messages[1] else {
        panic!("Expected OpenAIMessageV1 for the first user message");
    };
    let ContentPart::Text(text_part) = &text_msg.content[0] else {
        panic!("Expected TextContentPart for the first user message");
    };
    assert_eq!(text_part.text, "What company is this logo from?");

    let MessageNum::OpenAIMessageV1(image_msg) = &messages[2] else {
        panic!("Expected OpenAIMessageV1 for the second user message");
    };
    // The failure text now names the variant that is actually matched.
    let ContentPart::ImageUrl(image_url) = &image_msg.content[0] else {
        panic!("Expected ContentPart::ImageUrl for the second user message");
    };
    assert_eq!(image_url.image_url.url, "https://iili.io/3Hs4FMg.png");
    assert_eq!(image_url.r#type, "image_url");
}
1071
/// An OpenAI prompt with a system instruction keeps the file message as the
/// second user message (index 2 of the request's message list).
#[test]
fn test_document_prompt() {
    let prompt = Prompt::new_rs(
        vec![
            MessageNum::OpenAIMessageV1(create_openai_chat_message()),
            MessageNum::OpenAIMessageV1(create_openai_file_message()),
        ],
        "gpt-4o",
        Provider::OpenAI,
        vec![MessageNum::OpenAIMessageV1(
            create_system_openai_chat_message(),
        )],
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    // Index 2 is the second user message; the failure text used to say "first".
    let MessageNum::OpenAIMessageV1(file_msg) = &prompt.request.messages()[2] else {
        panic!("Expected OpenAIMessageV1 for the second user message");
    };
    let ContentPart::FileContent(file_content) = &file_msg.content[0] else {
        panic!("Expected ContentPart::FileContent for the second user message");
    };
    assert_eq!(file_content.file.file_id.as_ref().unwrap(), "fileid");
    assert_eq!(file_content.file.filename.as_ref().unwrap(), "filename");
}
1104
/// A prompt built with the `Score` schema reports a response JSON schema.
#[test]
fn test_response_format_score() {
    let schema = Score::get_structured_output_schema();
    let user_message = MessageNum::OpenAIMessageV1(create_openai_chat_message());

    let prompt = Prompt::new_rs(
        vec![user_message],
        "gpt-4o",
        Provider::OpenAI,
        vec![],
        None,
        Some(schema),
        ResponseType::Null,
    )
    .unwrap();

    assert!(prompt.response_json_schema().is_some());
}
1122
/// An Anthropic text prompt exposes its two template variables and binds both.
#[test]
fn test_anthropic_text_message_binding() {
    use crate::anthropic::v1::request::ContentBlock;

    let message = MessageParam {
        role: "user".to_string(),
        content: vec![ContentBlockParam {
            inner: ContentBlock::Text(TextBlockParam::new_rs(
                "Test prompt. ${param1} ${param2}".to_string(),
                None,
                None,
            )),
        }],
    };

    let prompt = Prompt::new_rs(
        vec![MessageNum::AnthropicMessageV1(message)],
        "claude-3-5-sonnet-20241022",
        Provider::Anthropic,
        vec![],
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    assert_eq!(prompt.request.messages().len(), 1);
    assert_eq!(prompt.parameters.len(), 2);

    let mut parameters = prompt.parameters.clone();
    parameters.sort();
    assert_eq!(parameters, ["param1", "param2"]);

    // Bind both variables one after the other on the first message.
    let bound_msg = prompt.request.messages()[0]
        .bind("param1", "Value1")
        .and_then(|partially_bound| partially_bound.bind("param2", "Value2"))
        .unwrap();

    let MessageNum::AnthropicMessageV1(msg) = bound_msg else {
        panic!("Expected AnthropicMessageV1");
    };
    let ContentBlock::Text(text_block) = &msg.content[0].inner else {
        panic!("Expected TextBlockParam");
    };
    assert_eq!(text_block.text, "Test prompt. Value1 Value2");
}
1172
/// An Anthropic prompt keeps its text and URL-image user messages at indices
/// 0 and 1 when a system instruction is also supplied.
#[test]
fn test_anthropic_url_image_prompt() {
    use crate::anthropic::v1::request::{ContentBlock, ImageSource};

    let prompt = Prompt::new_rs(
        vec![
            MessageNum::AnthropicMessageV1(create_anthropic_text_message()),
            MessageNum::AnthropicMessageV1(create_anthropic_url_image_message()),
        ],
        "claude-3-5-sonnet-20241022",
        Provider::Anthropic,
        vec![MessageNum::AnthropicMessageV1(
            create_anthropic_system_message(),
        )],
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    let messages = prompt.request.messages();

    let MessageNum::AnthropicMessageV1(text_msg) = &messages[0] else {
        panic!("Expected AnthropicMessageV1");
    };
    let ContentBlock::Text(text_block) = &text_msg.content[0].inner else {
        panic!("Expected TextBlock for first message");
    };
    assert_eq!(text_block.text, "What company is this logo from?");

    let MessageNum::AnthropicMessageV1(image_msg) = &messages[1] else {
        panic!("Expected AnthropicMessageV1");
    };
    let ContentBlock::Image(image_block) = &image_msg.content[0].inner else {
        panic!("Expected ImageBlock for second message");
    };
    let ImageSource::Url(url_source) = &image_block.source else {
        panic!("Expected URL image source");
    };
    assert_eq!(url_source.url, "https://iili.io/3Hs4FMg.png");
    assert_eq!(url_source.r#type, "url");
    assert_eq!(image_block.r#type, "image");
}
1226
1227 #[test]
1228 fn test_anthropic_base64_image_prompt() {
1229 let text_message = create_anthropic_text_message();
1230 let image_message = create_anthropic_base64_image_message();
1231
1232 let prompt = Prompt::new_rs(
1233 vec![
1234 MessageNum::AnthropicMessageV1(text_message),
1235 MessageNum::AnthropicMessageV1(image_message),
1236 ],
1237 "claude-3-5-sonnet-20241022",
1238 Provider::Anthropic,
1239 vec![],
1240 None,
1241 None,
1242 ResponseType::Null,
1243 )
1244 .unwrap();
1245
1246 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1248 if let crate::anthropic::v1::request::ContentBlock::Image(image_block) =
1249 &msg.content[0].inner
1250 {
1251 match &image_block.source {
1252 crate::anthropic::v1::request::ImageSource::Base64(base64_source) => {
1253 assert_eq!(base64_source.media_type, "image/png");
1254 assert_eq!(base64_source.data, "base64data");
1255 assert_eq!(base64_source.r#type, "base64");
1256 }
1257 _ => panic!("Expected Base64 image source"),
1258 }
1259 } else {
1260 panic!("Expected ImageBlock");
1261 }
1262 } else {
1263 panic!("Expected AnthropicMessageV1");
1264 }
1265 }
1266
1267 #[test]
1269 fn test_anthropic_base64_pdf_document_prompt() {
1270 let text_message = create_anthropic_text_message();
1271 let pdf_message = create_anthropic_base64_pdf_message();
1272 let system_message = create_anthropic_system_message();
1273
1274 let prompt = Prompt::new_rs(
1275 vec![
1276 MessageNum::AnthropicMessageV1(text_message),
1277 MessageNum::AnthropicMessageV1(pdf_message),
1278 ],
1279 "claude-3-5-sonnet-20241022",
1280 Provider::Anthropic,
1281 vec![MessageNum::AnthropicMessageV1(system_message)],
1282 None,
1283 None,
1284 ResponseType::Null,
1285 )
1286 .unwrap();
1287
1288 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1290 if let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
1291 &msg.content[0].inner
1292 {
1293 match &document_block.source {
1294 crate::anthropic::v1::request::DocumentSource::Base64(pdf_source) => {
1295 assert_eq!(pdf_source.media_type, "application/pdf");
1296 assert_eq!(pdf_source.data, "base64pdfdata");
1297 assert_eq!(pdf_source.r#type, "base64");
1298 }
1299 _ => panic!("Expected Base64 PDF source"),
1300 }
1301 assert_eq!(document_block.r#type, "document");
1302 assert_eq!(document_block.title.as_ref().unwrap(), "test_document.pdf");
1303 } else {
1304 panic!("Expected DocumentBlock");
1305 }
1306 } else {
1307 panic!("Expected AnthropicMessageV1");
1308 }
1309 }
1310
1311 #[test]
1313 fn test_anthropic_url_pdf_document_prompt() {
1314 let text_message = create_anthropic_text_message();
1315 let pdf_message = create_anthropic_url_pdf_message();
1316
1317 let prompt = Prompt::new_rs(
1318 vec![
1319 MessageNum::AnthropicMessageV1(text_message),
1320 MessageNum::AnthropicMessageV1(pdf_message),
1321 ],
1322 "claude-3-5-sonnet-20241022",
1323 Provider::Anthropic,
1324 vec![],
1325 None,
1326 None,
1327 ResponseType::Null,
1328 )
1329 .unwrap();
1330
1331 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1333 if let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
1334 &msg.content[0].inner
1335 {
1336 match &document_block.source {
1337 crate::anthropic::v1::request::DocumentSource::Url(url_source) => {
1338 assert_eq!(url_source.url, "https://example.com/document.pdf");
1339 assert_eq!(url_source.r#type, "url");
1340 }
1341 _ => panic!("Expected URL PDF source"),
1342 }
1343 } else {
1344 panic!("Expected DocumentBlock");
1345 }
1346 } else {
1347 panic!("Expected AnthropicMessageV1");
1348 }
1349 }
1350
1351 #[test]
1353 fn test_anthropic_plain_text_document_prompt() {
1354 let text_message = create_anthropic_text_message();
1355 let text_doc_message = create_anthropic_plain_text_document_message();
1356
1357 let prompt = Prompt::new_rs(
1358 vec![
1359 MessageNum::AnthropicMessageV1(text_message),
1360 MessageNum::AnthropicMessageV1(text_doc_message),
1361 ],
1362 "claude-3-5-sonnet-20241022",
1363 Provider::Anthropic,
1364 vec![],
1365 None,
1366 None,
1367 ResponseType::Null,
1368 )
1369 .unwrap();
1370
1371 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1373 if let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
1374 &msg.content[0].inner
1375 {
1376 match &document_block.source {
1377 crate::anthropic::v1::request::DocumentSource::Text(text_source) => {
1378 assert_eq!(text_source.media_type, "text/plain");
1379 assert_eq!(text_source.data, "Plain text document content");
1380 assert_eq!(text_source.r#type, "text");
1381 }
1382 _ => panic!("Expected Text document source"),
1383 }
1384 assert_eq!(
1385 document_block.context.as_ref().unwrap(),
1386 "Context for the document"
1387 );
1388 } else {
1389 panic!("Expected DocumentBlock");
1390 }
1391 } else {
1392 panic!("Expected AnthropicMessageV1");
1393 }
1394 }
1395
1396 #[test]
1398 fn test_anthropic_mixed_content_prompt() {
1399 let text_message = create_anthropic_text_message();
1400 let pdf_message = create_anthropic_base64_pdf_message();
1401 let text_doc_message = create_anthropic_plain_text_document_message();
1402 let system_message = create_anthropic_system_message();
1403
1404 let prompt = Prompt::new_rs(
1405 vec![
1406 MessageNum::AnthropicMessageV1(text_message),
1407 MessageNum::AnthropicMessageV1(pdf_message),
1408 MessageNum::AnthropicMessageV1(text_doc_message),
1409 ],
1410 "claude-3-5-sonnet-20241022",
1411 Provider::Anthropic,
1412 vec![MessageNum::AnthropicMessageV1(system_message)],
1413 None,
1414 None,
1415 ResponseType::Null,
1416 )
1417 .unwrap();
1418
1419 assert_eq!(prompt.request.messages().len(), 3);
1420 assert_eq!(prompt.request.system_instructions().len(), 1);
1421 assert_eq!(prompt.provider, Provider::Anthropic);
1422 assert_eq!(prompt.model, "claude-3-5-sonnet-20241022");
1423 }
1424
1425 #[test]
1427 fn test_gemini_chat_message() {
1428 let text = Part::from_text("Test prompt. ${param1} ${param2}".to_string());
1429 let message = GeminiContent {
1430 role: "user".to_string(),
1431 parts: vec![text],
1432 };
1433
1434 let prompt = Prompt::new_rs(
1435 vec![MessageNum::GeminiContentV1(message)],
1436 "gemini-1.5-pro",
1437 Provider::Google,
1438 vec![],
1439 None,
1440 None,
1441 ResponseType::Null,
1442 )
1443 .unwrap();
1444
1445 assert_eq!(prompt.request.messages().len(), 1);
1446 assert_eq!(prompt.parameters.len(), 2);
1447
1448 let mut parameters = prompt.parameters.clone();
1449 parameters.sort();
1450 assert_eq!(parameters[0], "param1");
1451 assert_eq!(parameters[1], "param2");
1452
1453 let bound_msg = prompt.request.messages()[0]
1455 .bind("param1", "Value1")
1456 .unwrap();
1457 let bound_msg = bound_msg.bind("param2", "Value2").unwrap();
1458
1459 match bound_msg {
1460 MessageNum::GeminiContentV1(msg) => {
1461 if let DataNum::Text(text_part) = &msg.parts[0].data {
1462 assert_eq!(text_part, "Test prompt. Value1 Value2");
1463 } else {
1464 panic!("Expected Text Part");
1465 }
1466 }
1467 _ => panic!("Expected GeminiContentV1"),
1468 }
1469 }
1470}