1use crate::anthropic::v1::request::{AnthropicSettings, MessageParam as AnthropicMessage};
2use crate::error::TypeError;
3use crate::google::v1::generate::request::{GeminiContent, GeminiSettings};
4use crate::openai::v1::chat::request::ChatMessage as OpenAIChatMessage;
5use crate::openai::v1::chat::settings::OpenAIChatSettings;
6use crate::prompt::builder::{to_provider_request, ProviderRequest};
7use crate::prompt::settings::ModelSettings;
8use crate::prompt::types::parse_response_to_json;
9use crate::prompt::types::ResponseType;
10use crate::prompt::types::Role;
11use crate::prompt::{AnthropicMessageList, GeminiContentList, MessageNum, OpenAIMessageList};
12use crate::tools::AgentToolDefinition;
13use crate::traits::MessageFactory;
14use crate::SettingsType;
15use crate::{Provider, SaveName};
16use potato_util::utils::extract_string_value;
17use potato_util::PyHelperFuncs;
18use potatohead_macro::try_extract_message;
19use pyo3::prelude::*;
20use pyo3::types::{PyDict, PyList, PyString, PyTuple};
21use pythonize::pythonize;
22use serde::{Deserialize, Serialize};
23use serde_json::Value;
24use std::collections::HashSet;
25use std::path::PathBuf;
26
27fn create_message_for_provider(
28 content: String,
29 provider: &Provider,
30 role: &str,
31) -> Result<MessageNum, TypeError> {
32 match provider {
33 Provider::OpenAI => {
34 OpenAIChatMessage::from_text(content, role).map(MessageNum::OpenAIMessageV1)
35 }
36 Provider::Anthropic => {
37 AnthropicMessage::from_text(content, role).map(MessageNum::AnthropicMessageV1)
38 }
39 Provider::Gemini | Provider::Google | Provider::Vertex => {
40 GeminiContent::from_text(content, role).map(MessageNum::GeminiContentV1)
41 }
42 _ => Err(TypeError::Error(format!(
43 "Unsupported provider for message creation: {:?}",
44 provider
45 ))),
46 }
47}
48
49fn parse_single_message(
50 message: &Bound<'_, PyAny>,
51 provider: &Provider,
52 default_role: &str,
53) -> Result<MessageNum, TypeError> {
54 if message.is_instance_of::<PyString>() {
56 let text = message.extract::<String>()?;
57 return create_message_for_provider(text, provider, default_role);
58 }
59
60 try_extract_message!(
62 message,
63 OpenAIChatMessage => MessageNum::OpenAIMessageV1,
64 AnthropicMessage => MessageNum::AnthropicMessageV1,
65 GeminiContent => MessageNum::GeminiContentV1,
66 );
67
68 Err(TypeError::InvalidMessageTypeInList(
69 message.get_type().name()?.to_string(),
70 ))
71}
72
73fn parse_messages(
74 messages: &Bound<'_, PyAny>,
75 provider: &Provider,
76 default_role: &str,
77) -> Result<Vec<MessageNum>, TypeError> {
78 let mut messages =
80 if !messages.is_instance_of::<PyList>() && !messages.is_instance_of::<PyTuple>() {
81 vec![parse_single_message(messages, provider, default_role)?]
82 } else {
83 messages
85 .try_iter()?
86 .map(|item| {
87 let item = item?;
88 parse_single_message(&item, provider, default_role)
89 })
90 .collect::<Result<Vec<_>, _>>()?
91 };
92
93 if provider == &Provider::Anthropic
96 && (default_role == Role::System.as_str()
97 || default_role == Role::Assistant.as_str()
98 || default_role == Role::Developer.as_str())
99 {
100 for msg in messages.iter_mut() {
101 msg.anthropic_message_to_system_message()?;
102 }
103 }
104
105 Ok(messages)
106}
107
108fn get_system_role(provider: &Provider) -> &'static str {
109 match provider {
110 Provider::OpenAI => Role::Developer.into(),
111 Provider::Gemini | Provider::Vertex | Provider::Google => Role::Model.into(),
112 Provider::Anthropic => Role::System.into(),
113 _ => Role::System.into(),
114 }
115}
116
117pub fn extract_system_instructions(
119 system_instruction: Option<&Bound<'_, PyAny>>,
120 provider: &Provider,
121) -> Result<Option<Vec<MessageNum>>, TypeError> {
122 let system_instructions = if let Some(sys_inst) = system_instruction {
123 Some(parse_messages(
124 sys_inst,
125 provider,
126 get_system_role(provider),
127 )?)
128 } else {
129 None
130 };
131
132 Ok(system_instructions)
133}
134
/// A provider-specific prompt: the built request payload plus the metadata
/// used to construct and bind it.
///
/// Created from Python through `new`, from Rust through `new_rs`, or loaded
/// from JSON/YAML with `from_path`.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Prompt {
    /// Request built by `to_provider_request`. It holds the messages, system
    /// instructions, model settings, and optional response JSON schema.
    pub request: ProviderRequest,

    /// Model name (read-only from Python).
    #[pyo3(get)]
    pub model: String,

    /// Provider the request targets (read-only from Python).
    #[pyo3(get)]
    pub provider: Provider,

    /// Set from `potato_util::version()` in `new_rs`. Presumably the library
    /// version that produced the prompt (TODO confirm).
    pub version: String,

    /// Template variable names found in the messages and system instructions
    /// (see `Prompt::extract_variables`). Empty when missing from serialized
    /// input; `from_path` re-derives it in that case.
    #[pyo3(get)]
    #[serde(default)]
    pub parameters: Vec<String>,

    /// Expected structured-response kind. `new` sets `ResponseType::Null` when
    /// no output type is given. Missing serialized values fall back to
    /// `ResponseType::default()`.
    #[serde(default)]
    pub response_type: ResponseType,
}
155
156fn extract_model_settings(model_settings: &Bound<'_, PyAny>) -> Result<ModelSettings, TypeError> {
158 let settings_type = model_settings
159 .call_method0("settings_type")?
160 .extract::<SettingsType>()?;
161
162 match settings_type {
163 SettingsType::OpenAIChat => model_settings
164 .extract::<OpenAIChatSettings>()
165 .map(ModelSettings::OpenAIChat),
166 SettingsType::GoogleChat => model_settings
167 .extract::<GeminiSettings>()
168 .map(ModelSettings::GoogleChat),
169 SettingsType::Anthropic => model_settings
170 .extract::<AnthropicSettings>()
171 .map(ModelSettings::AnthropicChat),
172 SettingsType::ModelSettings => model_settings.extract::<ModelSettings>(),
173 }
174 .map_err(Into::into)
175}
176
177#[pymethods]
178impl Prompt {
179 #[new]
195 #[pyo3(signature = (messages, model, provider, system_instructions=None, model_settings=None, output_type=None))]
196 pub fn new(
197 py: Python<'_>,
198 messages: &Bound<'_, PyAny>,
199 model: &str,
200 provider: &Bound<'_, PyAny>,
201 system_instructions: Option<&Bound<'_, PyAny>>,
202 model_settings: Option<&Bound<'_, PyAny>>,
203 output_type: Option<&Bound<'_, PyAny>>, ) -> Result<Self, TypeError> {
205 let model_settings = model_settings
207 .as_ref()
208 .map(|s| extract_model_settings(s))
209 .transpose()?;
210
211 let provider = Provider::extract_provider(provider)?;
213
214 let messages = parse_messages(messages, &provider, Role::User.into())?;
217 let system_instructions = if let Some(sys_inst) = system_instructions {
218 parse_messages(sys_inst, &provider, get_system_role(&provider))?
219 } else {
220 vec![]
221 };
222
223 let (response_type, response_json_schema) = match output_type {
225 Some(output_type) => {
226 parse_response_to_json(py, output_type)?
228 }
229 None => (ResponseType::Null, None),
230 };
231
232 Self::new_rs(
233 messages,
234 model,
235 provider,
236 system_instructions,
237 model_settings,
238 response_json_schema,
239 response_type,
240 )
241 }
242
243 #[getter]
244 pub fn model_settings<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
245 self.request.model_settings(py)
246 }
247
248 #[getter]
249 pub fn model_identifier(&self) -> String {
250 format!("{}:{}", self.provider.as_str(), self.model)
251 }
252
253 #[pyo3(signature = (path = None))]
254 pub fn save_prompt(&self, path: Option<PathBuf>) -> PyResult<PathBuf> {
255 let save_path = path.unwrap_or_else(|| PathBuf::from(SaveName::Prompt));
256 PyHelperFuncs::save_to_json(self, &save_path)?;
257 Ok(save_path)
258 }
259
260 #[staticmethod]
261 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
262 let content = std::fs::read_to_string(&path)?;
263
264 let extension = path
265 .extension()
266 .and_then(|ext| ext.to_str())
267 .ok_or_else(|| TypeError::Error(format!("Invalid file path: {:?}", path)))?;
268
269 let mut prompt = match extension.to_lowercase().as_str() {
270 "json" => {
271 let json_value: Prompt = serde_json::from_str(&content)?;
272 Ok(json_value)
273 }
274 "yaml" | "yml" => {
275 let yaml_value: Prompt = serde_yaml::from_str(&content)?;
276 Ok(yaml_value)
277 }
278 _ => Err(TypeError::Error(format!(
279 "Unsupported file extension '{}'. Expected .json, .yaml, or .yml",
280 extension
281 ))),
282 }?;
283
284 if prompt.parameters.is_empty() {
285 let system_instructions: Vec<MessageNum> = prompt
287 .request
288 .system_instructions()
289 .iter()
290 .map(|msg| (*msg).clone())
291 .collect();
292 let parameters =
293 Self::extract_variables(prompt.request.messages(), &system_instructions);
294 prompt.parameters = parameters;
295 }
296
297 Ok(prompt)
298 }
299
300 #[staticmethod]
301 pub fn model_validate_json(json_string: String) -> Result<Self, TypeError> {
302 let json_value: Value = serde_json::from_str(&json_string)?;
303 let model: Self = serde_json::from_value(json_value)?;
304
305 Ok(model)
306 }
307
308 pub fn model_dump_json(&self) -> String {
309 serde_json::to_string(self).unwrap()
310 }
311
312 pub fn __str__(&self) -> String {
313 PyHelperFuncs::__str__(self)
314 }
315
316 #[getter]
317 pub fn all_messages<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyList>, TypeError> {
319 self.request.get_all_py_messages(py)
320 }
321
322 #[getter]
323 pub fn messages<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyList>, TypeError> {
325 self.request.get_py_messages(py)
326 }
327
328 #[getter]
329 pub fn message<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
331 self.request.get_py_message(py)
332 }
333
334 #[getter]
335 pub fn openai_messages(&self) -> Result<OpenAIMessageList, TypeError> {
338 if self.provider != Provider::OpenAI {
339 return Err(TypeError::Error(
340 "Prompt provider is not OpenAI".to_string(),
341 ));
342 }
343 let messages = self
344 .request
345 .messages()
346 .iter()
347 .filter(|msg| msg.is_user_message())
348 .filter_map(|msg| match msg {
349 MessageNum::OpenAIMessageV1(m) => Some(m.clone()),
350 _ => None,
351 })
352 .collect::<Vec<_>>();
353 Ok(OpenAIMessageList { messages })
354 }
355
356 #[getter]
357 pub fn openai_message(&self) -> Result<OpenAIChatMessage, TypeError> {
360 if self.provider != Provider::OpenAI {
361 return Err(TypeError::Error(
362 "Prompt provider is not OpenAI".to_string(),
363 ));
364 }
365 self.request.get_openai_message()
366 }
367
368 #[getter]
369 pub fn gemini_messages(&self) -> Result<GeminiContentList, TypeError> {
372 if !self.is_google_provider() {
373 return Err(TypeError::Error(
374 "Prompt provider is not Google, Gemini, or Vertex".to_string(),
375 ));
376 }
377 let messages = self
378 .request
379 .messages()
380 .iter()
381 .filter(|msg| msg.is_user_message())
382 .filter_map(|msg| match msg {
383 MessageNum::GeminiContentV1(m) => Some(m.clone()),
384 _ => None,
385 })
386 .collect::<Vec<_>>();
387
388 Ok(GeminiContentList { messages })
389 }
390
391 #[getter]
392 pub fn gemini_message(&self) -> Result<GeminiContent, TypeError> {
395 if !self.is_google_provider() {
396 return Err(TypeError::Error(
397 "Prompt provider is not Google, Gemini, or Vertex".to_string(),
398 ));
399 }
400 self.request.get_gemini_message()
401 }
402
403 #[getter]
404 pub fn anthropic_messages(&self) -> Result<AnthropicMessageList, TypeError> {
405 if self.provider != Provider::Anthropic {
406 return Err(TypeError::Error(
407 "Prompt provider is not Anthropic".to_string(),
408 ));
409 }
410 let messages = self
411 .request
412 .messages()
413 .iter()
414 .filter(|msg| msg.is_user_message())
415 .filter_map(|msg| match msg {
416 MessageNum::AnthropicMessageV1(m) => Some(m.clone()),
417 _ => None,
418 })
419 .collect::<Vec<_>>();
420
421 Ok(AnthropicMessageList { messages })
422 }
423
424 #[getter]
425 pub fn anthropic_message(&self) -> Result<AnthropicMessage, TypeError> {
427 if self.provider != Provider::Anthropic {
428 return Err(TypeError::Error(
429 "Prompt provider is not Anthropic".to_string(),
430 ));
431 }
432 self.request.get_anthropic_message()
433 }
434
435 #[getter]
436 pub fn system_instructions<'py>(
437 &self,
438 py: Python<'py>,
439 ) -> Result<Bound<'py, PyList>, TypeError> {
440 self.request.get_py_system_instructions(py)
441 }
442
443 #[pyo3(signature = (name=None, value=None, **kwargs))]
451 pub fn bind(
452 &self,
453 name: Option<&str>,
454 value: Option<&Bound<'_, PyAny>>,
455 kwargs: Option<&Bound<'_, PyDict>>,
456 ) -> Result<Self, TypeError> {
457 let mut new_prompt = self.clone();
458
459 if let (Some(name), Some(value)) = (name, value) {
460 let var_value = extract_string_value(value)?;
461 for message in new_prompt.request.messages_mut() {
462 message.bind_mut(name, &var_value)?;
463 }
464 }
465
466 if let Some(kwargs) = kwargs {
467 for (key, val) in kwargs.iter() {
468 let var_name = key.extract::<String>()?;
469 let var_value = extract_string_value(&val)?;
470
471 for message in new_prompt.request.messages_mut() {
472 message.bind_mut(&var_name, &var_value)?;
473 }
474 }
475 }
476
477 if name.is_none() && kwargs.is_none_or(|k| k.is_empty()) {
478 return Err(TypeError::Error(
479 "Must provide either (name, value) or keyword arguments for binding".to_string(),
480 ));
481 }
482
483 Ok(new_prompt)
484 }
485
486 #[pyo3(signature = (name=None, value=None, **kwargs))]
493 pub fn bind_mut(
494 &mut self,
495 name: Option<&str>,
496 value: Option<&Bound<'_, PyAny>>,
497 kwargs: Option<&Bound<'_, PyDict>>,
498 ) -> Result<(), TypeError> {
499 if let (Some(name), Some(value)) = (name, value) {
500 let var_value = extract_string_value(value)?;
501 for message in self.request.messages_mut() {
502 message.bind_mut(name, &var_value)?;
503 }
504 }
505
506 if let Some(kwargs) = kwargs {
507 for (key, val) in kwargs.iter() {
508 let var_name = key.extract::<String>()?;
509 let var_value = extract_string_value(&val)?;
510
511 for message in self.request.messages_mut() {
512 message.bind_mut(&var_name, &var_value)?;
513 }
514 }
515 }
516
517 if name.is_none() && kwargs.is_none_or(|k| k.is_empty()) {
518 return Err(TypeError::Error(
519 "Must provide either (name, value) or keyword arguments for binding".to_string(),
520 ));
521 }
522
523 Ok(())
524 }
525
526 #[getter]
527 pub fn response_json_schema_pretty(&self) -> Option<String> {
528 Some(PyHelperFuncs::__str__(
529 self.request.response_json_schema().as_ref()?,
530 ))
531 }
532
533 #[getter]
534 #[pyo3(name = "response_json_schema")]
535 pub fn response_json_schema_py(&self) -> Option<String> {
536 Some(self.request.response_json_schema().as_ref()?.to_string())
537 }
538
539 pub fn model_dump<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
540 let request = &self.request.to_json()?;
541 Ok(pythonize(py, request)?)
542 }
543}
544
545impl Prompt {
546 pub fn response_json_schema(&self) -> Option<&Value> {
547 self.request.response_json_schema()
548 }
549
550 pub fn new_rs(
551 messages: Vec<MessageNum>,
552 model: &str,
553 provider: Provider,
554 system_instructions: Vec<MessageNum>,
555 model_settings: Option<ModelSettings>,
556 response_json_schema: Option<Value>,
557 response_type: ResponseType,
558 ) -> Result<Self, TypeError> {
559 let model = model.to_string();
560 let version = potato_util::version();
562 let model_settings = match model_settings {
564 Some(settings) => {
565 settings.validate_provider(&provider)?;
567 settings
568 }
569 None => ModelSettings::provider_default_settings(&provider),
570 };
571
572 let parameters = Self::extract_variables(&messages, &system_instructions);
574
575 let request = to_provider_request(
577 messages,
578 system_instructions,
579 model.clone(),
580 model_settings,
581 response_json_schema,
582 )?;
583
584 Ok(Self {
585 request,
586 version,
587 parameters,
588 response_type,
589 model,
590 provider,
591 })
592 }
593
594 fn is_google_provider(&self) -> bool {
595 matches!(
596 self.provider,
597 Provider::Google | Provider::Gemini | Provider::Vertex
598 )
599 }
600
601 pub fn add_tools(&mut self, tools: Vec<AgentToolDefinition>) -> Result<(), TypeError> {
602 self.request.add_tools(tools)
603 }
604
605 pub fn extract_variables(
606 messages: &[MessageNum],
607 system_instructions: &[MessageNum],
608 ) -> Vec<String> {
609 let mut variables = HashSet::new();
610
611 for msg in system_instructions {
613 variables.extend(msg.extract_variables());
614 }
615
616 for msg in messages {
618 variables.extend(msg.extract_variables());
619 }
620
621 variables.into_iter().collect()
622 }
623
624 pub fn model_dump_value(&self) -> Value {
625 serde_json::to_value(self).unwrap_or(Value::Null)
627 }
628
629 pub fn to_request_json(&self) -> Result<Value, TypeError> {
630 let json_value = serde_json::to_value(self)?;
632
633 Ok(json_value)
634 }
635
636 pub fn set_response_json_schema(
637 &mut self,
638 response_json_schema: Option<Value>,
639 response_type: ResponseType,
640 ) {
641 self.request.set_response_json_schema(response_json_schema);
642 self.response_type = response_type;
643 }
644}
645
646#[cfg(test)]
648mod tests {
649 use super::*;
650 use crate::anthropic::v1::request::{
651 Base64ImageSource, Base64PDFSource, ContentBlockParam, DocumentBlockParam, ImageBlockParam,
652 MessageParam, PlainTextSource, TextBlockParam, UrlImageSource, UrlPDFSource,
653 };
654 use crate::google::{DataNum, GeminiContent, Part};
655 use crate::openai::v1::chat::request::{
656 ChatMessage as OpenAIChatMessage, ContentPart, FileContentPart, ImageContentPart,
657 TextContentPart,
658 };
659 use crate::prompt::types::Score;
660 use crate::StructuredOutput;
661
662 fn create_openai_chat_message() -> OpenAIChatMessage {
663 let text_part = TextContentPart::new("What company is this logo from?".to_string());
664 let text_content_part = ContentPart::Text(text_part);
665 OpenAIChatMessage {
666 role: "user".to_string(),
667 content: vec![text_content_part],
668 name: None,
669 }
670 }
671
672 fn create_system_openai_chat_message() -> OpenAIChatMessage {
673 let text_part = TextContentPart::new("system_prompt".to_string());
674 let text_content_part = ContentPart::Text(text_part);
675 OpenAIChatMessage {
676 role: "developer".to_string(),
677 content: vec![text_content_part],
678 name: None,
679 }
680 }
681
682 fn create_openai_image_message() -> OpenAIChatMessage {
683 let image_part = ImageContentPart::new("https://iili.io/3Hs4FMg.png".to_string(), None);
684 let image_content_part = ContentPart::ImageUrl(image_part);
685 OpenAIChatMessage {
686 role: "user".to_string(),
687 content: vec![image_content_part],
688 name: None,
689 }
690 }
691
692 fn create_openai_file_message() -> OpenAIChatMessage {
693 let file_part = FileContentPart::new(
694 Some("filedata".to_string()),
695 Some("fileid".to_string()),
696 Some("filename".to_string()),
697 );
698 let file_content_part = ContentPart::FileContent(file_part);
699 OpenAIChatMessage {
700 role: "user".to_string(),
701 content: vec![file_content_part],
702 name: None,
703 }
704 }
705
706 fn create_anthropic_text_message() -> MessageParam {
707 let text_block =
708 TextBlockParam::new_rs("What company is this logo from?".to_string(), None, None);
709 MessageParam {
710 role: "user".to_string(),
711 content: vec![ContentBlockParam {
712 inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
713 }],
714 }
715 }
716
717 fn create_anthropic_system_message() -> MessageParam {
718 let text_block = TextBlockParam::new_rs("system_prompt".to_string(), None, None);
719 MessageParam {
720 role: "assistant".to_string(),
721 content: vec![ContentBlockParam {
722 inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
723 }],
724 }
725 }
726
727 fn create_anthropic_base64_image_message() -> MessageParam {
728 let image_source =
729 Base64ImageSource::new("image/png".to_string(), "base64data".to_string()).unwrap();
730 let image_block = ImageBlockParam {
731 source: crate::anthropic::v1::request::ImageSource::Base64(image_source),
732 cache_control: None,
733 r#type: "image".to_string(),
734 };
735 MessageParam {
736 role: "user".to_string(),
737 content: vec![ContentBlockParam {
738 inner: crate::anthropic::v1::request::ContentBlock::Image(image_block),
739 }],
740 }
741 }
742
743 fn create_anthropic_url_image_message() -> MessageParam {
744 let image_source = UrlImageSource::new("https://iili.io/3Hs4FMg.png".to_string());
745 let image_block = ImageBlockParam {
746 source: crate::anthropic::v1::request::ImageSource::Url(image_source),
747 cache_control: None,
748 r#type: "image".to_string(),
749 };
750 MessageParam {
751 role: "user".to_string(),
752 content: vec![ContentBlockParam {
753 inner: crate::anthropic::v1::request::ContentBlock::Image(image_block),
754 }],
755 }
756 }
757
758 fn create_anthropic_base64_pdf_message() -> MessageParam {
759 let pdf_source = Base64PDFSource::new("base64pdfdata".to_string()).unwrap();
760 let document_block = DocumentBlockParam {
761 source: crate::anthropic::v1::request::DocumentSource::Base64(pdf_source),
762 cache_control: None,
763 title: Some("test_document.pdf".to_string()),
764 context: None,
765 r#type: "document".to_string(),
766 citations: None,
767 };
768 MessageParam {
769 role: "user".to_string(),
770 content: vec![ContentBlockParam {
771 inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
772 }],
773 }
774 }
775
776 fn create_anthropic_url_pdf_message() -> MessageParam {
777 let pdf_source = UrlPDFSource::new("https://example.com/document.pdf".to_string());
778 let document_block = DocumentBlockParam {
779 source: crate::anthropic::v1::request::DocumentSource::Url(pdf_source),
780 cache_control: None,
781 title: Some("test_document.pdf".to_string()),
782 context: None,
783 r#type: "document".to_string(),
784 citations: None,
785 };
786 MessageParam {
787 role: "user".to_string(),
788 content: vec![ContentBlockParam {
789 inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
790 }],
791 }
792 }
793
794 fn create_anthropic_plain_text_document_message() -> MessageParam {
795 let text_source = PlainTextSource::new("Plain text document content".to_string());
796 let document_block = DocumentBlockParam {
797 source: crate::anthropic::v1::request::DocumentSource::Text(text_source),
798 cache_control: None,
799 title: Some("text_document.txt".to_string()),
800 context: Some("Context for the document".to_string()),
801 r#type: "document".to_string(),
802 citations: None,
803 };
804 MessageParam {
805 role: "user".to_string(),
806 content: vec![ContentBlockParam {
807 inner: crate::anthropic::v1::request::ContentBlock::Document(document_block),
808 }],
809 }
810 }
811
812 #[test]
813 fn test_task_list_add_and_get() {
814 let text_part = TextContentPart::new("Test prompt. ${param1} ${param2}".to_string());
815 let content_part = ContentPart::Text(text_part);
816 let message = OpenAIChatMessage {
817 role: "user".to_string(),
818 content: vec![content_part],
819 name: None,
820 };
821
822 let prompt = Prompt::new_rs(
823 vec![MessageNum::OpenAIMessageV1(message)],
824 "gpt-4o",
825 Provider::OpenAI,
826 vec![],
827 None,
828 None,
829 ResponseType::Null,
830 )
831 .unwrap();
832
833 assert_eq!(prompt.request.messages().len(), 1);
835
836 assert!(prompt.parameters.len() == 2);
838
839 let mut parameters = prompt.parameters.clone();
841 parameters.sort();
842
843 assert_eq!(parameters[0], "param1");
844 assert_eq!(parameters[1], "param2");
845
846 let bound_msg = prompt.request.messages()[0]
848 .bind("param1", "Value1")
849 .unwrap();
850 let bound_msg = bound_msg.bind("param2", "Value2").unwrap();
851
852 match bound_msg.clone() {
854 MessageNum::OpenAIMessageV1(msg) => {
855 if let ContentPart::Text(text_part) = &msg.content[0] {
856 assert_eq!(text_part.text, "Test prompt. Value1 Value2");
857 } else {
858 panic!("Expected TextContentPart");
859 }
860 }
861 _ => panic!("Expected OpenAIMessageV1"),
862 }
863 }
864
865 #[test]
866 fn test_image_prompt() {
867 let text_message = create_openai_chat_message();
868 let image_message = create_openai_image_message();
869
870 let system_text_part = TextContentPart::new("system_prompt".to_string());
871 let system_text_content_part = ContentPart::Text(system_text_part);
872
873 let system_text_message = OpenAIChatMessage {
874 role: "assistant".to_string(),
875 content: vec![system_text_content_part],
876 name: None,
877 };
878
879 let prompt = Prompt::new_rs(
880 vec![
881 MessageNum::OpenAIMessageV1(text_message),
882 MessageNum::OpenAIMessageV1(image_message),
883 ],
884 "gpt-4o",
885 Provider::OpenAI,
886 vec![MessageNum::OpenAIMessageV1(system_text_message)],
887 None,
888 None,
889 ResponseType::Null,
890 )
891 .unwrap();
892
893 if let MessageNum::OpenAIMessageV1(msg) = &prompt.request.messages()[1] {
895 if let ContentPart::Text(text_part) = &msg.content[0] {
896 assert_eq!(text_part.text, "What company is this logo from?");
897 } else {
898 panic!("Expected TextContentPart for the first user message");
899 }
900 } else {
901 panic!("Expected OpenAIMessageV1 for the first user message");
902 }
903
904 if let MessageNum::OpenAIMessageV1(msg) = &prompt.request.messages()[2] {
906 if let ContentPart::ImageUrl(image_url) = &msg.content[0] {
907 assert_eq!(image_url.image_url.url, "https://iili.io/3Hs4FMg.png");
908 assert_eq!(image_url.r#type, "image_url");
909 } else {
910 panic!("Expected ContentPart::Image for the second user message");
911 }
912 } else {
913 panic!("Expected OpenAIMessageV1 for the second user message");
914 }
915 }
916
917 #[test]
918 fn test_document_prompt() {
919 let text_message = create_openai_chat_message();
920 let file_message = create_openai_file_message();
921 let system_message = create_system_openai_chat_message();
922
923 let prompt = Prompt::new_rs(
924 vec![
925 MessageNum::OpenAIMessageV1(text_message),
926 MessageNum::OpenAIMessageV1(file_message),
927 ],
928 "gpt-4o",
929 Provider::OpenAI,
930 vec![MessageNum::OpenAIMessageV1(system_message)],
931 None,
932 None,
933 ResponseType::Null,
934 )
935 .unwrap();
936
937 if let MessageNum::OpenAIMessageV1(msg) = &prompt.request.messages()[2] {
939 if let ContentPart::FileContent(file_content) = &msg.content[0] {
940 assert_eq!(file_content.file.file_id.as_ref().unwrap(), "fileid");
941 assert_eq!(file_content.file.filename.as_ref().unwrap(), "filename");
942 } else {
943 panic!("Expected ContentPart::FileContent for the second user message");
944 }
945 } else {
946 panic!("Expected OpenAIMessageV1 for the first user message");
947 }
948 }
949
950 #[test]
951 fn test_response_format_score() {
952 let text_message = create_openai_chat_message();
953 let prompt = Prompt::new_rs(
954 vec![MessageNum::OpenAIMessageV1(text_message)],
955 "gpt-4o",
956 Provider::OpenAI,
957 vec![],
958 None,
959 Some(Score::get_structured_output_schema()),
960 ResponseType::Null,
961 )
962 .unwrap();
963
964 assert!(prompt.response_json_schema().is_some());
966 }
967
968 #[test]
969 fn test_anthropic_text_message_binding() {
970 let text_block =
971 TextBlockParam::new_rs("Test prompt. ${param1} ${param2}".to_string(), None, None);
972 let message = MessageParam {
973 role: "user".to_string(),
974 content: vec![ContentBlockParam {
975 inner: crate::anthropic::v1::request::ContentBlock::Text(text_block),
976 }],
977 };
978
979 let prompt = Prompt::new_rs(
980 vec![MessageNum::AnthropicMessageV1(message)],
981 "claude-3-5-sonnet-20241022",
982 Provider::Anthropic,
983 vec![],
984 None,
985 None,
986 ResponseType::Null,
987 )
988 .unwrap();
989
990 assert_eq!(prompt.request.messages().len(), 1);
991 assert_eq!(prompt.parameters.len(), 2);
992
993 let mut parameters = prompt.parameters.clone();
994 parameters.sort();
995 assert_eq!(parameters[0], "param1");
996 assert_eq!(parameters[1], "param2");
997
998 let bound_msg = prompt.request.messages()[0]
1000 .bind("param1", "Value1")
1001 .unwrap();
1002 let bound_msg = bound_msg.bind("param2", "Value2").unwrap();
1003
1004 match bound_msg {
1005 MessageNum::AnthropicMessageV1(msg) => {
1006 if let crate::anthropic::v1::request::ContentBlock::Text(text_block) =
1007 &msg.content[0].inner
1008 {
1009 assert_eq!(text_block.text, "Test prompt. Value1 Value2");
1010 } else {
1011 panic!("Expected TextBlockParam");
1012 }
1013 }
1014 _ => panic!("Expected AnthropicMessageV1"),
1015 }
1016 }
1017
1018 #[test]
1019 fn test_anthropic_url_image_prompt() {
1020 let text_message = create_anthropic_text_message();
1021 let image_message = create_anthropic_url_image_message();
1022 let system_message = create_anthropic_system_message();
1023
1024 let prompt = Prompt::new_rs(
1025 vec![
1026 MessageNum::AnthropicMessageV1(text_message),
1027 MessageNum::AnthropicMessageV1(image_message),
1028 ],
1029 "claude-3-5-sonnet-20241022",
1030 Provider::Anthropic,
1031 vec![MessageNum::AnthropicMessageV1(system_message)],
1032 None,
1033 None,
1034 ResponseType::Null,
1035 )
1036 .unwrap();
1037
1038 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[0] {
1040 if let crate::anthropic::v1::request::ContentBlock::Text(text_block) =
1041 &msg.content[0].inner
1042 {
1043 assert_eq!(text_block.text, "What company is this logo from?");
1044 } else {
1045 panic!("Expected TextBlock for first message");
1046 }
1047 } else {
1048 panic!("Expected AnthropicMessageV1");
1049 }
1050
1051 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1053 if let crate::anthropic::v1::request::ContentBlock::Image(image_block) =
1054 &msg.content[0].inner
1055 {
1056 match &image_block.source {
1057 crate::anthropic::v1::request::ImageSource::Url(url_source) => {
1058 assert_eq!(url_source.url, "https://iili.io/3Hs4FMg.png");
1059 assert_eq!(url_source.r#type, "url");
1060 }
1061 _ => panic!("Expected URL image source"),
1062 }
1063 assert_eq!(image_block.r#type, "image");
1064 } else {
1065 panic!("Expected ImageBlock for second message");
1066 }
1067 } else {
1068 panic!("Expected AnthropicMessageV1");
1069 }
1070 }
1071
1072 #[test]
1073 fn test_anthropic_base64_image_prompt() {
1074 let text_message = create_anthropic_text_message();
1075 let image_message = create_anthropic_base64_image_message();
1076
1077 let prompt = Prompt::new_rs(
1078 vec![
1079 MessageNum::AnthropicMessageV1(text_message),
1080 MessageNum::AnthropicMessageV1(image_message),
1081 ],
1082 "claude-3-5-sonnet-20241022",
1083 Provider::Anthropic,
1084 vec![],
1085 None,
1086 None,
1087 ResponseType::Null,
1088 )
1089 .unwrap();
1090
1091 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1093 if let crate::anthropic::v1::request::ContentBlock::Image(image_block) =
1094 &msg.content[0].inner
1095 {
1096 match &image_block.source {
1097 crate::anthropic::v1::request::ImageSource::Base64(base64_source) => {
1098 assert_eq!(base64_source.media_type, "image/png");
1099 assert_eq!(base64_source.data, "base64data");
1100 assert_eq!(base64_source.r#type, "base64");
1101 }
1102 _ => panic!("Expected Base64 image source"),
1103 }
1104 } else {
1105 panic!("Expected ImageBlock");
1106 }
1107 } else {
1108 panic!("Expected AnthropicMessageV1");
1109 }
1110 }
1111
1112 #[test]
1114 fn test_anthropic_base64_pdf_document_prompt() {
1115 let text_message = create_anthropic_text_message();
1116 let pdf_message = create_anthropic_base64_pdf_message();
1117 let system_message = create_anthropic_system_message();
1118
1119 let prompt = Prompt::new_rs(
1120 vec![
1121 MessageNum::AnthropicMessageV1(text_message),
1122 MessageNum::AnthropicMessageV1(pdf_message),
1123 ],
1124 "claude-3-5-sonnet-20241022",
1125 Provider::Anthropic,
1126 vec![MessageNum::AnthropicMessageV1(system_message)],
1127 None,
1128 None,
1129 ResponseType::Null,
1130 )
1131 .unwrap();
1132
1133 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1135 if let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
1136 &msg.content[0].inner
1137 {
1138 match &document_block.source {
1139 crate::anthropic::v1::request::DocumentSource::Base64(pdf_source) => {
1140 assert_eq!(pdf_source.media_type, "application/pdf");
1141 assert_eq!(pdf_source.data, "base64pdfdata");
1142 assert_eq!(pdf_source.r#type, "base64");
1143 }
1144 _ => panic!("Expected Base64 PDF source"),
1145 }
1146 assert_eq!(document_block.r#type, "document");
1147 assert_eq!(document_block.title.as_ref().unwrap(), "test_document.pdf");
1148 } else {
1149 panic!("Expected DocumentBlock");
1150 }
1151 } else {
1152 panic!("Expected AnthropicMessageV1");
1153 }
1154 }
1155
1156 #[test]
1158 fn test_anthropic_url_pdf_document_prompt() {
1159 let text_message = create_anthropic_text_message();
1160 let pdf_message = create_anthropic_url_pdf_message();
1161
1162 let prompt = Prompt::new_rs(
1163 vec![
1164 MessageNum::AnthropicMessageV1(text_message),
1165 MessageNum::AnthropicMessageV1(pdf_message),
1166 ],
1167 "claude-3-5-sonnet-20241022",
1168 Provider::Anthropic,
1169 vec![],
1170 None,
1171 None,
1172 ResponseType::Null,
1173 )
1174 .unwrap();
1175
1176 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1178 if let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
1179 &msg.content[0].inner
1180 {
1181 match &document_block.source {
1182 crate::anthropic::v1::request::DocumentSource::Url(url_source) => {
1183 assert_eq!(url_source.url, "https://example.com/document.pdf");
1184 assert_eq!(url_source.r#type, "url");
1185 }
1186 _ => panic!("Expected URL PDF source"),
1187 }
1188 } else {
1189 panic!("Expected DocumentBlock");
1190 }
1191 } else {
1192 panic!("Expected AnthropicMessageV1");
1193 }
1194 }
1195
1196 #[test]
1198 fn test_anthropic_plain_text_document_prompt() {
1199 let text_message = create_anthropic_text_message();
1200 let text_doc_message = create_anthropic_plain_text_document_message();
1201
1202 let prompt = Prompt::new_rs(
1203 vec![
1204 MessageNum::AnthropicMessageV1(text_message),
1205 MessageNum::AnthropicMessageV1(text_doc_message),
1206 ],
1207 "claude-3-5-sonnet-20241022",
1208 Provider::Anthropic,
1209 vec![],
1210 None,
1211 None,
1212 ResponseType::Null,
1213 )
1214 .unwrap();
1215
1216 if let MessageNum::AnthropicMessageV1(msg) = &prompt.request.messages()[1] {
1218 if let crate::anthropic::v1::request::ContentBlock::Document(document_block) =
1219 &msg.content[0].inner
1220 {
1221 match &document_block.source {
1222 crate::anthropic::v1::request::DocumentSource::Text(text_source) => {
1223 assert_eq!(text_source.media_type, "text/plain");
1224 assert_eq!(text_source.data, "Plain text document content");
1225 assert_eq!(text_source.r#type, "text");
1226 }
1227 _ => panic!("Expected Text document source"),
1228 }
1229 assert_eq!(
1230 document_block.context.as_ref().unwrap(),
1231 "Context for the document"
1232 );
1233 } else {
1234 panic!("Expected DocumentBlock");
1235 }
1236 } else {
1237 panic!("Expected AnthropicMessageV1");
1238 }
1239 }
1240
1241 #[test]
1243 fn test_anthropic_mixed_content_prompt() {
1244 let text_message = create_anthropic_text_message();
1245 let pdf_message = create_anthropic_base64_pdf_message();
1246 let text_doc_message = create_anthropic_plain_text_document_message();
1247 let system_message = create_anthropic_system_message();
1248
1249 let prompt = Prompt::new_rs(
1250 vec![
1251 MessageNum::AnthropicMessageV1(text_message),
1252 MessageNum::AnthropicMessageV1(pdf_message),
1253 MessageNum::AnthropicMessageV1(text_doc_message),
1254 ],
1255 "claude-3-5-sonnet-20241022",
1256 Provider::Anthropic,
1257 vec![MessageNum::AnthropicMessageV1(system_message)],
1258 None,
1259 None,
1260 ResponseType::Null,
1261 )
1262 .unwrap();
1263
1264 assert_eq!(prompt.request.messages().len(), 3);
1265 assert_eq!(prompt.request.system_instructions().len(), 1);
1266 assert_eq!(prompt.provider, Provider::Anthropic);
1267 assert_eq!(prompt.model, "claude-3-5-sonnet-20241022");
1268 }
1269
1270 #[test]
1272 fn test_gemini_chat_message() {
1273 let text = Part::from_text("Test prompt. ${param1} ${param2}".to_string());
1274 let message = GeminiContent {
1275 role: "user".to_string(),
1276 parts: vec![text],
1277 };
1278
1279 let prompt = Prompt::new_rs(
1280 vec![MessageNum::GeminiContentV1(message)],
1281 "gemini-1.5-pro",
1282 Provider::Google,
1283 vec![],
1284 None,
1285 None,
1286 ResponseType::Null,
1287 )
1288 .unwrap();
1289
1290 assert_eq!(prompt.request.messages().len(), 1);
1291 assert_eq!(prompt.parameters.len(), 2);
1292
1293 let mut parameters = prompt.parameters.clone();
1294 parameters.sort();
1295 assert_eq!(parameters[0], "param1");
1296 assert_eq!(parameters[1], "param2");
1297
1298 let bound_msg = prompt.request.messages()[0]
1300 .bind("param1", "Value1")
1301 .unwrap();
1302 let bound_msg = bound_msg.bind("param2", "Value2").unwrap();
1303
1304 match bound_msg {
1305 MessageNum::GeminiContentV1(msg) => {
1306 if let DataNum::Text(text_part) = &msg.parts[0].data {
1307 assert_eq!(text_part, "Test prompt. Value1 Value2");
1308 } else {
1309 panic!("Expected Text Part");
1310 }
1311 }
1312 _ => panic!("Expected GeminiContentV1"),
1313 }
1314 }
1315}