1use crate::anthropic::v1::request::{AnthropicSettings, MessageParam as AnthropicMessage};
2use crate::error::TypeError;
3use crate::google::v1::generate::request::{GeminiContent, GeminiSettings};
4use crate::openai::v1::chat::request::ChatMessage as OpenAIChatMessage;
5use crate::openai::v1::chat::settings::OpenAIChatSettings;
6use crate::prompt::builder::{to_provider_request, ProviderRequest};
7use crate::prompt::settings::ModelSettings;
8use crate::prompt::types::parse_response_to_json;
9use crate::prompt::types::ResponseType;
10use crate::prompt::types::Role;
11use crate::prompt::{AnthropicMessageList, GeminiContentList, MessageNum, OpenAIMessageList};
12use crate::tools::AgentToolDefinition;
13use crate::traits::MessageFactory;
14use crate::SettingsType;
15use crate::{Provider, SaveName};
16use potato_util::utils::extract_string_value;
17use potato_util::PyHelperFuncs;
18use potatohead_macro::try_extract_message;
19use pyo3::prelude::*;
20use pyo3::types::{PyDict, PyList, PyString, PyTuple};
21use pythonize::pythonize;
22use serde::{Deserialize, Serialize};
23use serde_json::Value;
24use std::collections::HashSet;
25use std::path::PathBuf;
26
/// Minimal, provider-agnostic prompt definition accepted by `Prompt::from_path`.
///
/// Messages and system instructions are plain strings; `Prompt::from_generic_config`
/// turns them into provider-specific messages.
#[derive(Debug, Deserialize)]
struct GenericPromptConfig {
    /// Model identifier forwarded to the provider request.
    model: String,
    /// Provider name, resolved with `Provider::from_string`.
    provider: String,
    /// Message texts, each converted to a message with the user role.
    messages: Vec<String>,
    /// Optional system instruction texts, converted with the provider's system role.
    #[serde(default)]
    system_instructions: Option<Vec<String>>,
    /// Optional raw settings, deserialized per provider by `Prompt::settings_from_value`.
    #[serde(default)]
    settings: Option<Value>,
}
50
51fn create_message_for_provider(
52 content: String,
53 provider: &Provider,
54 role: &str,
55) -> Result<MessageNum, TypeError> {
56 match provider {
57 Provider::OpenAI => {
58 OpenAIChatMessage::from_text(content, role).map(MessageNum::OpenAIMessageV1)
59 }
60 Provider::Anthropic => {
61 AnthropicMessage::from_text(content, role).map(MessageNum::AnthropicMessageV1)
62 }
63 Provider::Gemini | Provider::Google | Provider::Vertex => {
64 GeminiContent::from_text(content, role).map(MessageNum::GeminiContentV1)
65 }
66 _ => Err(TypeError::Error(format!(
67 "Unsupported provider for message creation: {:?}",
68 provider
69 ))),
70 }
71}
72
73fn parse_single_message(
74 message: &Bound<'_, PyAny>,
75 provider: &Provider,
76 default_role: &str,
77) -> Result<MessageNum, TypeError> {
78 if message.is_instance_of::<PyString>() {
80 let text = message.extract::<String>()?;
81 return create_message_for_provider(text, provider, default_role);
82 }
83
84 try_extract_message!(
86 message,
87 OpenAIChatMessage => MessageNum::OpenAIMessageV1,
88 AnthropicMessage => MessageNum::AnthropicMessageV1,
89 GeminiContent => MessageNum::GeminiContentV1,
90 );
91
92 Err(TypeError::InvalidMessageTypeInList(
93 message.get_type().name()?.to_string(),
94 ))
95}
96
97fn parse_messages(
98 messages: &Bound<'_, PyAny>,
99 provider: &Provider,
100 default_role: &str,
101) -> Result<Vec<MessageNum>, TypeError> {
102 let mut messages =
104 if !messages.is_instance_of::<PyList>() && !messages.is_instance_of::<PyTuple>() {
105 vec![parse_single_message(messages, provider, default_role)?]
106 } else {
107 messages
109 .try_iter()?
110 .map(|item| {
111 let item = item?;
112 parse_single_message(&item, provider, default_role)
113 })
114 .collect::<Result<Vec<_>, _>>()?
115 };
116
117 if provider == &Provider::Anthropic
120 && (default_role == Role::System.as_str()
121 || default_role == Role::Assistant.as_str()
122 || default_role == Role::Developer.as_str())
123 {
124 for msg in messages.iter_mut() {
125 msg.anthropic_message_to_system_message()?;
126 }
127 }
128
129 Ok(messages)
130}
131
132fn get_system_role(provider: &Provider) -> &'static str {
133 match provider {
134 Provider::OpenAI => Role::Developer.into(),
135 Provider::Gemini | Provider::Vertex | Provider::Google => Role::Model.into(),
136 Provider::Anthropic => Role::System.into(),
137 _ => Role::System.into(),
138 }
139}
140
141pub fn extract_system_instructions(
143 system_instruction: Option<&Bound<'_, PyAny>>,
144 provider: &Provider,
145) -> Result<Option<Vec<MessageNum>>, TypeError> {
146 let system_instructions = if let Some(sys_inst) = system_instruction {
147 Some(parse_messages(
148 sys_inst,
149 provider,
150 get_system_role(provider),
151 )?)
152 } else {
153 None
154 };
155
156 Ok(system_instructions)
157}
158
/// A provider-specific prompt: the built provider request plus the metadata used to
/// save, reload, and bind it.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Prompt {
    /// Request built by `to_provider_request` from the messages, system
    /// instructions, model, settings, and optional response schema.
    pub request: ProviderRequest,

    /// Model identifier (readable from Python).
    #[pyo3(get)]
    pub model: String,

    /// Provider the request was built for (readable from Python).
    #[pyo3(get)]
    pub provider: Provider,

    /// Set from `potato_util::version()` when built via `new_rs`; deserialized
    /// prompts keep the stored value.
    pub version: String,

    /// Template variable names found in the messages and system instructions.
    /// Defaults to empty when absent from serialized input; `from_path` recomputes
    /// it when it is empty.
    #[pyo3(get)]
    #[serde(default)]
    pub parameters: Vec<String>,

    /// Expected response type; falls back to `ResponseType::default()` when absent
    /// from serialized input.
    #[serde(default)]
    pub response_type: ResponseType,
}
179
180fn extract_model_settings(model_settings: &Bound<'_, PyAny>) -> Result<ModelSettings, TypeError> {
182 let settings_type = model_settings
183 .call_method0("settings_type")?
184 .extract::<SettingsType>()?;
185
186 match settings_type {
187 SettingsType::OpenAIChat => model_settings
188 .extract::<OpenAIChatSettings>()
189 .map(ModelSettings::OpenAIChat),
190 SettingsType::GoogleChat => model_settings
191 .extract::<GeminiSettings>()
192 .map(ModelSettings::GoogleChat),
193 SettingsType::Anthropic => model_settings
194 .extract::<AnthropicSettings>()
195 .map(ModelSettings::AnthropicChat),
196 SettingsType::ModelSettings => model_settings.extract::<ModelSettings>(),
197 }
198 .map_err(Into::into)
199}
200
201#[pymethods]
202impl Prompt {
203 #[new]
219 #[pyo3(signature = (messages, model, provider, system_instructions=None, model_settings=None, output_type=None))]
220 pub fn new(
221 py: Python<'_>,
222 messages: &Bound<'_, PyAny>,
223 model: &str,
224 provider: &Bound<'_, PyAny>,
225 system_instructions: Option<&Bound<'_, PyAny>>,
226 model_settings: Option<&Bound<'_, PyAny>>,
227 output_type: Option<&Bound<'_, PyAny>>, ) -> Result<Self, TypeError> {
229 let model_settings = model_settings
231 .as_ref()
232 .map(|s| extract_model_settings(s))
233 .transpose()?;
234
235 let provider = Provider::extract_provider(provider)?;
237
238 let messages = parse_messages(messages, &provider, Role::User.into())?;
241 let system_instructions = if let Some(sys_inst) = system_instructions {
242 parse_messages(sys_inst, &provider, get_system_role(&provider))?
243 } else {
244 vec![]
245 };
246
247 let (response_type, response_json_schema) = match output_type {
249 Some(output_type) => {
250 parse_response_to_json(py, output_type)?
252 }
253 None => (ResponseType::Null, None),
254 };
255
256 Self::new_rs(
257 messages,
258 model,
259 provider,
260 system_instructions,
261 model_settings,
262 response_json_schema,
263 response_type,
264 )
265 }
266
267 #[getter]
268 pub fn model_settings<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
269 self.request.model_settings(py)
270 }
271
272 #[getter]
273 pub fn model_identifier(&self) -> String {
274 format!("{}:{}", self.provider.as_str(), self.model)
275 }
276
277 #[pyo3(signature = (path = None))]
278 pub fn save_prompt(&self, path: Option<PathBuf>) -> PyResult<PathBuf> {
279 let save_path = path.unwrap_or_else(|| PathBuf::from(SaveName::Prompt));
280 PyHelperFuncs::save_to_json(self, &save_path)?;
281 Ok(save_path)
282 }
283
284 #[staticmethod]
285 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
286 let content = std::fs::read_to_string(&path)?;
287
288 let extension = path
289 .extension()
290 .and_then(|ext| ext.to_str())
291 .ok_or_else(|| TypeError::Error(format!("Invalid file path: {:?}", path)))?;
292
293 let mut prompt = match extension.to_lowercase().as_str() {
294 "json" => {
295 if let Ok(config) = serde_json::from_str::<GenericPromptConfig>(&content) {
296 Self::from_generic_config(config)?
297 } else {
298 serde_json::from_str(&content)?
299 }
300 }
301 "yaml" | "yml" => {
302 if let Ok(config) = serde_yaml::from_str::<GenericPromptConfig>(&content) {
303 Self::from_generic_config(config)?
304 } else {
305 serde_yaml::from_str(&content)?
306 }
307 }
308 _ => {
309 return Err(TypeError::Error(format!(
310 "Unsupported file extension '{}'. Expected .json, .yaml, or .yml",
311 extension
312 )))
313 }
314 };
315
316 if prompt.parameters.is_empty() {
317 let system_instructions: Vec<MessageNum> = prompt
318 .request
319 .system_instructions()
320 .iter()
321 .map(|msg| (*msg).clone())
322 .collect();
323 let parameters =
324 Self::extract_variables(prompt.request.messages(), &system_instructions);
325 prompt.parameters = parameters;
326 }
327
328 Ok(prompt)
329 }
330
331 #[staticmethod]
332 pub fn model_validate_json(json_string: String) -> Result<Self, TypeError> {
333 let json_value: Value = serde_json::from_str(&json_string)?;
334 let model: Self = serde_json::from_value(json_value)?;
335
336 Ok(model)
337 }
338
339 pub fn model_dump_json(&self) -> String {
340 serde_json::to_string(self).unwrap()
341 }
342
343 pub fn __str__(&self) -> String {
344 PyHelperFuncs::__str__(self)
345 }
346
347 #[getter]
348 pub fn all_messages<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyList>, TypeError> {
350 self.request.get_all_py_messages(py)
351 }
352
353 #[getter]
354 pub fn messages<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyList>, TypeError> {
356 self.request.get_py_messages(py)
357 }
358
359 #[getter]
360 pub fn message<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
362 self.request.get_py_message(py)
363 }
364
365 #[getter]
366 pub fn openai_messages(&self) -> Result<OpenAIMessageList, TypeError> {
369 if self.provider != Provider::OpenAI {
370 return Err(TypeError::Error(
371 "Prompt provider is not OpenAI".to_string(),
372 ));
373 }
374 let messages = self
375 .request
376 .messages()
377 .iter()
378 .filter(|msg| msg.is_user_message())
379 .filter_map(|msg| match msg {
380 MessageNum::OpenAIMessageV1(m) => Some(m.clone()),
381 _ => None,
382 })
383 .collect::<Vec<_>>();
384 Ok(OpenAIMessageList { messages })
385 }
386
387 #[getter]
388 pub fn openai_message(&self) -> Result<OpenAIChatMessage, TypeError> {
391 if self.provider != Provider::OpenAI {
392 return Err(TypeError::Error(
393 "Prompt provider is not OpenAI".to_string(),
394 ));
395 }
396 self.request.get_openai_message()
397 }
398
399 #[getter]
400 pub fn gemini_messages(&self) -> Result<GeminiContentList, TypeError> {
403 if !self.is_google_provider() {
404 return Err(TypeError::Error(
405 "Prompt provider is not Google, Gemini, or Vertex".to_string(),
406 ));
407 }
408 let messages = self
409 .request
410 .messages()
411 .iter()
412 .filter(|msg| msg.is_user_message())
413 .filter_map(|msg| match msg {
414 MessageNum::GeminiContentV1(m) => Some(m.clone()),
415 _ => None,
416 })
417 .collect::<Vec<_>>();
418
419 Ok(GeminiContentList { messages })
420 }
421
422 #[getter]
423 pub fn gemini_message(&self) -> Result<GeminiContent, TypeError> {
426 if !self.is_google_provider() {
427 return Err(TypeError::Error(
428 "Prompt provider is not Google, Gemini, or Vertex".to_string(),
429 ));
430 }
431 self.request.get_gemini_message()
432 }
433
434 #[getter]
435 pub fn anthropic_messages(&self) -> Result<AnthropicMessageList, TypeError> {
436 if self.provider != Provider::Anthropic {
437 return Err(TypeError::Error(
438 "Prompt provider is not Anthropic".to_string(),
439 ));
440 }
441 let messages = self
442 .request
443 .messages()
444 .iter()
445 .filter(|msg| msg.is_user_message())
446 .filter_map(|msg| match msg {
447 MessageNum::AnthropicMessageV1(m) => Some(m.clone()),
448 _ => None,
449 })
450 .collect::<Vec<_>>();
451
452 Ok(AnthropicMessageList { messages })
453 }
454
455 #[getter]
456 pub fn anthropic_message(&self) -> Result<AnthropicMessage, TypeError> {
458 if self.provider != Provider::Anthropic {
459 return Err(TypeError::Error(
460 "Prompt provider is not Anthropic".to_string(),
461 ));
462 }
463 self.request.get_anthropic_message()
464 }
465
466 #[getter]
467 pub fn system_instructions<'py>(
468 &self,
469 py: Python<'py>,
470 ) -> Result<Bound<'py, PyList>, TypeError> {
471 self.request.get_py_system_instructions(py)
472 }
473
474 #[pyo3(signature = (name=None, value=None, **kwargs))]
482 pub fn bind(
483 &self,
484 name: Option<&str>,
485 value: Option<&Bound<'_, PyAny>>,
486 kwargs: Option<&Bound<'_, PyDict>>,
487 ) -> Result<Self, TypeError> {
488 let mut new_prompt = self.clone();
489
490 if let (Some(name), Some(value)) = (name, value) {
491 let var_value = extract_string_value(value)?;
492 for message in new_prompt.request.messages_mut() {
493 message.bind_mut(name, &var_value)?;
494 }
495 }
496
497 if let Some(kwargs) = kwargs {
498 for (key, val) in kwargs.iter() {
499 let var_name = key.extract::<String>()?;
500 let var_value = extract_string_value(&val)?;
501
502 for message in new_prompt.request.messages_mut() {
503 message.bind_mut(&var_name, &var_value)?;
504 }
505 }
506 }
507
508 if name.is_none() && kwargs.is_none_or(|k| k.is_empty()) {
509 return Err(TypeError::Error(
510 "Must provide either (name, value) or keyword arguments for binding".to_string(),
511 ));
512 }
513
514 Ok(new_prompt)
515 }
516
517 #[pyo3(signature = (name=None, value=None, **kwargs))]
524 pub fn bind_mut(
525 &mut self,
526 name: Option<&str>,
527 value: Option<&Bound<'_, PyAny>>,
528 kwargs: Option<&Bound<'_, PyDict>>,
529 ) -> Result<(), TypeError> {
530 if let (Some(name), Some(value)) = (name, value) {
531 let var_value = extract_string_value(value)?;
532 for message in self.request.messages_mut() {
533 message.bind_mut(name, &var_value)?;
534 }
535 }
536
537 if let Some(kwargs) = kwargs {
538 for (key, val) in kwargs.iter() {
539 let var_name = key.extract::<String>()?;
540 let var_value = extract_string_value(&val)?;
541
542 for message in self.request.messages_mut() {
543 message.bind_mut(&var_name, &var_value)?;
544 }
545 }
546 }
547
548 if name.is_none() && kwargs.is_none_or(|k| k.is_empty()) {
549 return Err(TypeError::Error(
550 "Must provide either (name, value) or keyword arguments for binding".to_string(),
551 ));
552 }
553
554 Ok(())
555 }
556
557 #[getter]
558 pub fn response_json_schema_pretty(&self) -> Option<String> {
559 Some(PyHelperFuncs::__str__(
560 self.request.response_json_schema().as_ref()?,
561 ))
562 }
563
564 #[getter]
565 #[pyo3(name = "response_json_schema")]
566 pub fn response_json_schema_py(&self) -> Option<String> {
567 Some(self.request.response_json_schema().as_ref()?.to_string())
568 }
569
570 pub fn model_dump<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
571 let request = &self.request.to_json()?;
572 Ok(pythonize(py, request)?)
573 }
574}
575
576impl Prompt {
577 fn from_generic_config(config: GenericPromptConfig) -> Result<Self, TypeError> {
580 if config.messages.is_empty() {
582 return Err(TypeError::Error(
583 "Prompt has no messages. Generic prompt format requires at least one message."
584 .to_string(),
585 ));
586 }
587
588 let provider = Provider::from_string(&config.provider)?;
590
591 let messages: Vec<MessageNum> = config
593 .messages
594 .into_iter()
595 .map(|msg| create_message_for_provider(msg, &provider, Role::User.as_str()))
596 .collect::<Result<Vec<_>, _>>()?;
597
598 let system_instructions = if let Some(sys_inst) = config.system_instructions {
600 sys_inst
601 .into_iter()
602 .map(|msg| create_message_for_provider(msg, &provider, get_system_role(&provider)))
603 .collect::<Result<Vec<_>, _>>()?
604 } else {
605 Vec::new()
606 };
607
608 let model_settings = if let Some(settings) = config.settings {
610 Some(Self::settings_from_value(settings, &provider)?)
611 } else {
612 None
613 };
614
615 Self::new_rs(
617 messages,
618 &config.model,
619 provider,
620 system_instructions,
621 model_settings,
622 None,
623 ResponseType::Null,
624 )
625 }
626
627 fn settings_from_value(value: Value, provider: &Provider) -> Result<ModelSettings, TypeError> {
630 match provider {
631 Provider::OpenAI => {
632 let settings: OpenAIChatSettings = serde_json::from_value(value)?;
633 Ok(ModelSettings::OpenAIChat(settings))
634 }
635 Provider::Anthropic => {
636 let settings: AnthropicSettings = serde_json::from_value(value)?;
637 Ok(ModelSettings::AnthropicChat(settings))
638 }
639 Provider::Gemini | Provider::Google | Provider::Vertex => {
640 let settings: GeminiSettings = serde_json::from_value(value)?;
641 Ok(ModelSettings::GoogleChat(settings))
642 }
643 _ => Err(TypeError::Error(format!(
644 "Settings not supported for provider: {:?}",
645 provider
646 ))),
647 }
648 }
649
650 pub fn response_json_schema(&self) -> Option<&Value> {
651 self.request.response_json_schema()
652 }
653
654 pub fn new_rs(
655 messages: Vec<MessageNum>,
656 model: &str,
657 provider: Provider,
658 system_instructions: Vec<MessageNum>,
659 model_settings: Option<ModelSettings>,
660 response_json_schema: Option<Value>,
661 response_type: ResponseType,
662 ) -> Result<Self, TypeError> {
663 let model = model.to_string();
664 let version = potato_util::version();
666 let model_settings = match model_settings {
668 Some(settings) => {
669 settings.validate_provider(&provider)?;
671 settings
672 }
673 None => ModelSettings::provider_default_settings(&provider),
674 };
675
676 let parameters = Self::extract_variables(&messages, &system_instructions);
678
679 let request = to_provider_request(
681 messages,
682 system_instructions,
683 model.clone(),
684 model_settings,
685 response_json_schema,
686 )?;
687
688 Ok(Self {
689 request,
690 version,
691 parameters,
692 response_type,
693 model,
694 provider,
695 })
696 }
697
698 fn is_google_provider(&self) -> bool {
699 matches!(
700 self.provider,
701 Provider::Google | Provider::Gemini | Provider::Vertex
702 )
703 }
704
705 pub fn add_tools(&mut self, tools: Vec<AgentToolDefinition>) -> Result<(), TypeError> {
706 self.request.add_tools(tools)
707 }
708
709 pub fn extract_variables(
710 messages: &[MessageNum],
711 system_instructions: &[MessageNum],
712 ) -> Vec<String> {
713 let mut variables = HashSet::new();
714
715 for msg in system_instructions {
717 variables.extend(msg.extract_variables());
718 }
719
720 for msg in messages {
722 variables.extend(msg.extract_variables());
723 }
724
725 variables.into_iter().collect()
726 }
727
728 pub fn model_dump_value(&self) -> Value {
729 serde_json::to_value(self).unwrap_or(Value::Null)
731 }
732
733 pub fn to_request_json(&self) -> Result<Value, TypeError> {
734 let json_value = serde_json::to_value(self)?;
736
737 Ok(json_value)
738 }
739
740 pub fn set_response_json_schema(
741 &mut self,
742 response_json_schema: Option<Value>,
743 response_type: ResponseType,
744 ) {
745 self.request.set_response_json_schema(response_json_schema);
746 self.response_type = response_type;
747 }
748}
749
750#[cfg(test)]
752mod tests {
753 use super::*;
754 use crate::anthropic::v1::request::{
755 Base64ImageSource, Base64PDFSource, ContentBlockParam, DocumentBlockParam, ImageBlockParam,
756 MessageParam, PlainTextSource, TextBlockParam, UrlImageSource, UrlPDFSource,
757 };
758 use crate::google::{DataNum, GeminiContent, Part};
759 use crate::openai::v1::chat::request::{
760 ChatMessage as OpenAIChatMessage, ContentPart, FileContentPart, ImageContentPart,
761 TextContentPart,
762 };
763 use crate::prompt::types::Score;
764 use crate::StructuredOutput;
765
766 fn create_openai_chat_message() -> OpenAIChatMessage {
767 let text_part = TextContentPart::new("What company is this logo from?".to_string());
768 let text_content_part = ContentPart::Text(text_part);
769 OpenAIChatMessage {
770 role: "user".to_string(),
771 content: vec![text_content_part],
772 name: None,
773 }
774 }
775
776 fn create_system_openai_chat_message() -> OpenAIChatMessage {
777 let text_part = TextContentPart::new("system_prompt".to_string());
778 let text_content_part = ContentPart::Text(text_part);
779 OpenAIChatMessage {
780 role: "developer".to_string(),
781 content: vec![text_content_part],
782 name: None,
783 }
784 }
785
786 fn create_openai_image_message() -> OpenAIChatMessage {
787 let image_part = ImageContentPart::new("https://iili.io/3Hs4FMg.png".to_string(), None);
788 let image_content_part = ContentPart::ImageUrl(image_part);
789 OpenAIChatMessage {
790 role: "user".to_string(),
791 content: vec![image_content_part],
792 name: None,
793 }
794 }
795
796 fn create_openai_file_message() -> OpenAIChatMessage {
797 let file_part = FileContentPart::new(
798 Some("filedata".to_string()),
799 Some("fileid".to_string()),
800 Some("filename".to_string()),
801 );
802 let file_content_part = ContentPart::FileContent(file_part);
803 OpenAIChatMessage {
804 role: "user".to_string(),
805 content: vec![file_content_part],
806 name: None,
807 }
808 }
809
810 fn create_anthropic_text_message() -> MessageParam {
811 let text_block =
812 TextBlockParam::new_rs("What company is this logo from?".to_string(), None, None);
813 MessageParam {
814 role: "user".to_string(),
815 content: vec![ContentBlockParam {
816 inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
817 }],
818 }
819 }
820
821 fn create_anthropic_system_message() -> MessageParam {
822 let text_block = TextBlockParam::new_rs("system_prompt".to_string(), None, None);
823 MessageParam {
824 role: "assistant".to_string(),
825 content: vec![ContentBlockParam {
826 inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
827 }],
828 }
829 }
830
831 fn create_anthropic_base64_image_message() -> MessageParam {
832 let image_source =
833 Base64ImageSource::new("image/png".to_string(), "base64data".to_string()).unwrap();
834 let image_block = ImageBlockParam {
835 source: crate::anthropic::v1::request::ImageSource::Base64(image_source),
836 cache_control: None,
837 r#type: "image".to_string(),
838 };
839 MessageParam {
840 role: "user".to_string(),
841 content: vec![ContentBlockParam {
842 inner: crate::anthropic::v1::request::ContentBlock::Image(image_block),
843 }],
844 }
845 }
846
847 fn create_anthropic_url_image_message() -> MessageParam {
848 let image_source = UrlImageSource::new("https://iili.io/3Hs4FMg.png".to_string());
849 let image_block = ImageBlockParam {
850 source: crate::anthropic::v1::request::ImageSource::Url(image_source),
851 cache_control: None,
852 r#type: "image".to_string(),
853 };
854 MessageParam {
855 role: "user".to_string(),
856 content: vec![ContentBlockParam {
857 inner: crate::anthropic::v1::request::ContentBlock::Image(image_block),
858 }],
859 }
860 }
861
862 fn create_anthropic_base64_pdf_message() -> MessageParam {
863 let pdf_source = Base64PDFSource::new("base64pdfdata".to_string()).unwrap();
864 let document_block = DocumentBlockParam {
865 source: crate::anthropic::v1::request::DocumentSource::Base64(pdf_source),
866 cache_control: None,
867 title: Some("test_document.pdf".to_string()),
868 context: None,
869 r#type: "document".to_string(),
870 citations: None,
871 };
872 MessageParam {
873 role: "user".to_string(),
874 content: vec![ContentBlockParam {
875 inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
876 }],
877 }
878 }
879
880 fn create_anthropic_url_pdf_message() -> MessageParam {
881 let pdf_source = UrlPDFSource::new("https://example.com/document.pdf".to_string());
882 let document_block = DocumentBlockParam {
883 source: crate::anthropic::v1::request::DocumentSource::Url(pdf_source),
884 cache_control: None,
885 title: Some("test_document.pdf".to_string()),
886 context: None,
887 r#type: "document".to_string(),
888 citations: None,
889 };
890 MessageParam {
891 role: "user".to_string(),
892 content: vec![ContentBlockParam {
893 inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
894 }],
895 }
896 }
897
898 fn create_anthropic_plain_text_document_message() -> MessageParam {
899 let text_source = PlainTextSource::new("Plain text document content".to_string());
900 let document_block = DocumentBlockParam {
901 source: crate::anthropic::v1::request::DocumentSource::Text(text_source),
902 cache_control: None,
903 title: Some("text_document.txt".to_string()),
904 context: Some("Context for the document".to_string()),
905 r#type: "document".to_string(),
906 citations: None,
907 };
908 MessageParam {
909 role: "user".to_string(),
910 content: vec![ContentBlockParam {
911 inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
912 }],
913 }
914 }
915
916 #[test]
917 fn test_task_list_add_and_get() {
918 let text_part = TextContentPart::new("Test prompt. ${param1} ${param2}".to_string());
919 let content_part = ContentPart::Text(text_part);
920 let message = OpenAIChatMessage {
921 role: "user".to_string(),
922 content: vec![content_part],
923 name: None,
924 };
925
926 let prompt = Prompt::new_rs(
927 vec![MessageNum::OpenAIMessageV1(message)],
928 "gpt-4o",
929 Provider::OpenAI,
930 vec![],
931 None,
932 None,
933 ResponseType::Null,
934 )
935 .unwrap();
936
937 assert_eq!(prompt.request.messages().len(), 1);
939
940 assert!(prompt.parameters.len() == 2);
942
943 let mut parameters = prompt.parameters.clone();
945 parameters.sort();
946
947 assert_eq!(parameters[0], "param1");
948 assert_eq!(parameters[1], "param2");
949
950 let bound_msg = prompt.request.messages()[0]
952 .bind("param1", "Value1")
953 .unwrap();
954 let bound_msg = bound_msg.bind("param2", "Value2").unwrap();
955
956 match bound_msg.clone() {
958 MessageNum::OpenAIMessageV1(msg) => {
959 if let ContentPart::Text(text_part) = &msg.content[0] {
960 assert_eq!(text_part.text, "Test prompt. Value1 Value2");
961 } else {
962 panic!("Expected TextContentPart");
963 }
964 }
965 _ => panic!("Expected OpenAIMessageV1"),
966 }
967 }
968
969 #[test]
970 fn test_image_prompt() {
971 let text_message = create_openai_chat_message();
972 let image_message = create_openai_image_message();
973
974 let system_text_part = TextContentPart::new("system_prompt".to_string());
975 let system_text_content_part = ContentPart::Text(system_text_part);
976
977 let system_text_message = OpenAIChatMessage {
978 role: "assistant".to_string(),
979 content: vec![system_text_content_part],
980 name: None,
981 };
982
983 let prompt = Prompt::new_rs(
984 vec![
985 MessageNum::OpenAIMessageV1(text_message),
986 MessageNum::OpenAIMessageV1(image_message),
987 ],
988 "gpt-4o",
989 Provider::OpenAI,
990 vec![MessageNum::OpenAIMessageV1(system_text_message)],
991 None,
992 None,
993 ResponseType::Null,
994 )
995 .unwrap();
996
997 if let MessageNum::OpenAIMessageV1(msg) = &prompt.request.messages()[1] {
999 if let ContentPart::Text(text_part) = &msg.content[0] {
1000 assert_eq!(text_part.text, "What company is this logo from?");
1001 } else {
1002 panic!("Expected TextContentPart for the first user message");
1003 }
1004 } else {
1005 panic!("Expected OpenAIMessageV1 for the first user message");
1006 }
1007
1008 if let MessageNum::OpenAIMessageV1(msg) = &prompt.request.messages()[2] {
1010 if let ContentPart::ImageUrl(image_url) = &msg.content[0] {
1011 assert_eq!(image_url.image_url.url, "https://iili.io/3Hs4FMg.png");
1012 assert_eq!(image_url.r#type, "image_url");
1013 } else {
1014 panic!("Expected ContentPart::Image for the second user message");
1015 }
1016 } else {
1017 panic!("Expected OpenAIMessageV1 for the second user message");
1018 }
1019 }
1020
1021 #[test]
1022 fn test_document_prompt() {
1023 let text_message = create_openai_chat_message();
1024 let file_message = create_openai_file_message();
1025 let system_message = create_system_openai_chat_message();
1026
1027 let prompt = Prompt::new_rs(
1028 vec![
1029 MessageNum::OpenAIMessageV1(text_message),
1030 MessageNum::OpenAIMessageV1(file_message),
1031 ],
1032 "gpt-4o",
1033 Provider::OpenAI,
1034 vec![MessageNum::OpenAIMessageV1(system_message)],
1035 None,
1036 None,
1037 ResponseType::Null,
1038 )
1039 .unwrap();
1040
1041 if let MessageNum::OpenAIMessageV1(msg) = &prompt.request.messages()[2] {
1043 if let ContentPart::FileContent(file_content) = &msg.content[0] {
1044 assert_eq!(file_content.file.file_id.as_ref().unwrap(), "fileid");
1045 assert_eq!(file_content.file.filename.as_ref().unwrap(), "filename");
1046 } else {
1047 panic!("Expected ContentPart::FileContent for the second user message");
1048 }
1049 } else {
1050 panic!("Expected OpenAIMessageV1 for the first user message");
1051 }
1052 }
1053
1054 #[test]
1055 fn test_response_format_score() {
1056 let text_message = create_openai_chat_message();
1057 let prompt = Prompt::new_rs(
1058 vec![MessageNum::OpenAIMessageV1(text_message)],
1059 "gpt-4o",
1060 Provider::OpenAI,
1061 vec![],
1062 None,
1063 Some(Score::get_structured_output_schema()),
1064 ResponseType::Null,
1065 )
1066 .unwrap();
1067
1068 assert!(prompt.response_json_schema().is_some());
1070 }
1071
1072 #[test]
1073 fn test_anthropic_text_message_binding() {
1074 let text_block =
1075 TextBlockParam::new_rs("Test prompt. ${param1} ${param2}".to_string(), None, None);
1076 let message = MessageParam {
1077 role: "user".to_string(),
1078 content: vec![ContentBlockParam {
1079 inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
1080 }],
1081 };
1082
1083 let prompt = Prompt::new_rs(
1084 vec![MessageNum::AnthropicMessageV1(message)],
1085 "claude-3-5-sonnet-20241022",
1086 Provider::Anthropic,
1087 vec![],
1088 None,
1089 None,
1090 ResponseType::Null,
1091 )
1092 .unwrap();
1093
1094 assert_eq!(prompt.request.messages().len(), 1);
1095 assert_eq!(prompt.parameters.len(), 2);
1096
1097 let mut parameters = prompt.parameters.clone();
1098 parameters.sort();
1099 assert_eq!(parameters[0], "param1");
1100 assert_eq!(parameters[1], "param2");
1101
1102 let bound_msg = prompt.request.messages()[0]
1104 .bind("param1", "Value1")
1105 .unwrap();
1106 let bound_msg = bound_msg.bind("param2", "Value2").unwrap();
1107
1108 match bound_msg {
1109 MessageNum::AnthropicMessageV1(msg) => {
1110 if let crate::anthropic::v1::request::ContentBlock::Text(text_block) =
1111 &msg.content[0].inner
1112 {
1113 assert_eq!(text_block.text, "Test prompt. Value1 Value2");
1114 } else {
1115 panic!("Expected TextBlockParam");
1116 }
1117 }
1118 _ => panic!("Expected AnthropicMessageV1"),
1119 }
1120 }
1121
1122 #[test]
1123 fn test_anthropic_url_image_prompt() {
1124 let text_message = create_anthropic_text_message();
1125 let image_message = create_anthropic_url_image_message();
1126 let system_message = create_anthropic_system_message();
1127
1128 let prompt = Prompt::new_rs(
1129 vec![
1130 MessageNum::AnthropicMessageV1(text_message),
1131 MessageNum::AnthropicMessageV1(image_message),
1132 ],
1133 "claude-3-5-sonnet-20241022",
1134 Provider::Anthropic,
1135 vec![MessageNum::AnthropicMessageV1(system_message)],
1136 None,
1137 None,
1138 ResponseType::Null,
1139 )
1140 .unwrap();
1141
1142 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[0] {
1144 if let crate::anthropic::v1::request::ContentBlock::Text(text_block) =
1145 &msg.content[0].inner
1146 {
1147 assert_eq!(text_block.text, "What company is this logo from?");
1148 } else {
1149 panic!("Expected TextBlock for first message");
1150 }
1151 } else {
1152 panic!("Expected AnthropicMessageV1");
1153 }
1154
1155 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1157 if let crate::anthropic::v1::request::ContentBlock::Image(image_block) =
1158 &msg.content[0].inner
1159 {
1160 match &image_block.source {
1161 crate::anthropic::v1::request::ImageSource::Url(url_source) => {
1162 assert_eq!(url_source.url, "https://iili.io/3Hs4FMg.png");
1163 assert_eq!(url_source.r#type, "url");
1164 }
1165 _ => panic!("Expected URL image source"),
1166 }
1167 assert_eq!(image_block.r#type, "image");
1168 } else {
1169 panic!("Expected ImageBlock for second message");
1170 }
1171 } else {
1172 panic!("Expected AnthropicMessageV1");
1173 }
1174 }
1175
#[test]
/// Checks that a base64-encoded image message keeps its media type, payload and
/// source type when placed in an Anthropic prompt without system instructions.
fn test_anthropic_base64_image_prompt() {
    let prompt = Prompt::new_rs(
        vec![
            MessageNum::AnthropicMessageV1(create_anthropic_text_message()),
            MessageNum::AnthropicMessageV1(create_anthropic_base64_image_message()),
        ],
        "claude-3-5-sonnet-20241022",
        Provider::Anthropic,
        Vec::new(),
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    let messages = prompt.request.messages();

    // Only the second message carries the image; the first is plain text.
    let MessageNum::AnthropicMessageV1(image_msg) = &messages[1] else {
        panic!("Expected AnthropicMessageV1");
    };
    let crate::anthropic::v1::request::ContentBlock::Image(image_block) =
        &image_msg.content[0].inner
    else {
        panic!("Expected ImageBlock");
    };

    match &image_block.source {
        crate::anthropic::v1::request::ImageSource::Base64(encoded) => {
            assert_eq!(encoded.media_type, "image/png");
            assert_eq!(encoded.data, "base64data");
            assert_eq!(encoded.r#type, "base64");
        }
        _ => panic!("Expected Base64 image source"),
    }
}
1215
#[test]
/// Checks that a base64 PDF document message keeps its source fields, block type and
/// title inside an Anthropic prompt that also has a system message.
fn test_anthropic_base64_pdf_document_prompt() {
    // Helpers are invoked in the order: text, base64 pdf, system.
    let conversation = vec![
        MessageNum::AnthropicMessageV1(create_anthropic_text_message()),
        MessageNum::AnthropicMessageV1(create_anthropic_base64_pdf_message()),
    ];
    let system_prompt = vec![MessageNum::AnthropicMessageV1(create_anthropic_system_message())];

    let prompt = Prompt::new_rs(
        conversation,
        "claude-3-5-sonnet-20241022",
        Provider::Anthropic,
        system_prompt,
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    let messages = prompt.request.messages();

    // Pull the first content block of the second (document) message.
    let block = match &messages[1] {
        MessageNum::AnthropicMessageV1(msg) => &msg.content[0].inner,
        _ => panic!("Expected AnthropicMessageV1"),
    };
    let document_block = match block {
        crate::anthropic::v1::request::ContentBlock::Document(doc) => doc,
        _ => panic!("Expected DocumentBlock"),
    };

    let crate::anthropic::v1::request::DocumentSource::Base64(pdf_source) =
        &document_block.source
    else {
        panic!("Expected Base64 PDF source");
    };
    assert_eq!(pdf_source.media_type, "application/pdf");
    assert_eq!(pdf_source.data, "base64pdfdata");
    assert_eq!(pdf_source.r#type, "base64");

    assert_eq!(document_block.r#type, "document");
    assert_eq!(document_block.title.as_ref().unwrap(), "test_document.pdf");
}
1259
#[test]
/// Checks that a URL-sourced PDF document message keeps its URL and source type in an
/// Anthropic prompt without system instructions.
fn test_anthropic_url_pdf_document_prompt() {
    let conversation = vec![
        MessageNum::AnthropicMessageV1(create_anthropic_text_message()),
        MessageNum::AnthropicMessageV1(create_anthropic_url_pdf_message()),
    ];

    let prompt = Prompt::new_rs(
        conversation,
        "claude-3-5-sonnet-20241022",
        Provider::Anthropic,
        Vec::new(),
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    let messages = prompt.request.messages();

    let MessageNum::AnthropicMessageV1(doc_msg) = &messages[1] else {
        panic!("Expected AnthropicMessageV1");
    };
    let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
        &doc_msg.content[0].inner
    else {
        panic!("Expected DocumentBlock");
    };
    let crate::anthropic::v1::request::DocumentSource::Url(link) = &document_block.source else {
        panic!("Expected URL PDF source");
    };

    assert_eq!(link.url, "https://example.com/document.pdf");
    assert_eq!(link.r#type, "url");
}
1299
#[test]
/// Checks that a plain-text document message keeps its text source fields and its
/// context string in an Anthropic prompt without system instructions.
fn test_anthropic_plain_text_document_prompt() {
    let conversation = vec![
        MessageNum::AnthropicMessageV1(create_anthropic_text_message()),
        MessageNum::AnthropicMessageV1(create_anthropic_plain_text_document_message()),
    ];

    let prompt = Prompt::new_rs(
        conversation,
        "claude-3-5-sonnet-20241022",
        Provider::Anthropic,
        Vec::new(),
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    let messages = prompt.request.messages();

    let MessageNum::AnthropicMessageV1(doc_msg) = &messages[1] else {
        panic!("Expected AnthropicMessageV1");
    };
    let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
        &doc_msg.content[0].inner
    else {
        panic!("Expected DocumentBlock");
    };

    match &document_block.source {
        crate::anthropic::v1::request::DocumentSource::Text(plain) => {
            assert_eq!(plain.media_type, "text/plain");
            assert_eq!(plain.data, "Plain text document content");
            assert_eq!(plain.r#type, "text");
        }
        _ => panic!("Expected Text document source"),
    }

    let context = document_block.context.as_ref().unwrap();
    assert_eq!(context, "Context for the document");
}
1344
#[test]
/// Builds an Anthropic prompt from text, base64 PDF and plain-text document messages
/// plus one system message, and checks the counts, provider and model it reports.
fn test_anthropic_mixed_content_prompt() {
    let model = "claude-3-5-sonnet-20241022";

    // Helpers are invoked in the order: text, base64 pdf, text document, system.
    let conversation = vec![
        MessageNum::AnthropicMessageV1(create_anthropic_text_message()),
        MessageNum::AnthropicMessageV1(create_anthropic_base64_pdf_message()),
        MessageNum::AnthropicMessageV1(create_anthropic_plain_text_document_message()),
    ];
    let system_prompt = vec![MessageNum::AnthropicMessageV1(create_anthropic_system_message())];

    let prompt = Prompt::new_rs(
        conversation,
        model,
        Provider::Anthropic,
        system_prompt,
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    let request = &prompt.request;
    assert_eq!(request.messages().len(), 3);
    assert_eq!(request.system_instructions().len(), 1);
    assert_eq!(prompt.provider, Provider::Anthropic);
    assert_eq!(prompt.model, model);
}
1373
#[test]
/// Checks that `${…}` placeholders in a Gemini message are detected as prompt parameters
/// and that binding both of them yields the fully substituted text part.
fn test_gemini_chat_message() {
    let templated_part = Part::from_text("Test prompt. ${param1} ${param2}".to_string());
    let content = GeminiContent {
        role: "user".to_string(),
        parts: vec![templated_part],
    };

    let prompt = Prompt::new_rs(
        vec![MessageNum::GeminiContentV1(content)],
        "gemini-1.5-pro",
        Provider::Google,
        Vec::new(),
        None,
        None,
        ResponseType::Null,
    )
    .unwrap();

    assert_eq!(prompt.request.messages().len(), 1);
    assert_eq!(prompt.parameters.len(), 2);

    // Sort a copy so the assertions do not depend on the order parameters were collected in.
    let mut sorted_params = prompt.parameters.clone();
    sorted_params.sort();
    assert_eq!(sorted_params[0], "param1");
    assert_eq!(sorted_params[1], "param2");

    // Bind the placeholders one at a time.
    let first_bound = prompt.request.messages()[0]
        .bind("param1", "Value1")
        .unwrap();
    let fully_bound = first_bound.bind("param2", "Value2").unwrap();

    let MessageNum::GeminiContentV1(rendered) = fully_bound else {
        panic!("Expected GeminiContentV1");
    };
    let DataNum::Text(text) = &rendered.parts[0].data else {
        panic!("Expected Text Part");
    };
    assert_eq!(text, "Test prompt. Value1 Value2");
}
1419}