1use super::{prompt_message::PromptMessage, prompt_role::PromptRole};
4use bon::Builder;
5use dogma::{
6 prelude::{FromStr, String, Vec, fmt},
7 traits::Collection,
8};
9
/// A prompt made of an ordered list of [`PromptMessage`]s.
///
/// Can be built directly, through the `bon`-generated `Prompt::builder()`,
/// or through the `From`/`FromStr`/`TryFrom` conversions implemented for it.
// NOTE(review): `on(String, into)` makes `String`-typed builder setters accept
// `impl Into<String>`; no field here is a `String`, so it appears to have no
// effect at the moment — confirm whether it is kept for future fields.
#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd, Builder)]
#[builder(derive(Debug), on(String, into))]
pub struct Prompt {
    /// The messages of the prompt, in order.
    ///
    /// Defaults to an empty list when not set through the builder.
    #[builder(default)]
    pub messages: Vec<PromptMessage>,
}
16
17impl Collection for Prompt {
18 type Item = PromptMessage;
19
20 fn len(&self) -> usize {
21 self.messages.len()
22 }
23}
24
25impl From<Vec<PromptMessage>> for Prompt {
26 fn from(messages: Vec<PromptMessage>) -> Self {
27 Self { messages }
28 }
29}
30
31impl FromStr for Prompt {
32 type Err = ();
33
34 fn from_str(input: &str) -> Result<Self, Self::Err> {
35 Ok(input.into())
36 }
37}
38
39impl From<&str> for Prompt {
40 fn from(input: &str) -> Self {
41 (PromptRole::User, input).into()
42 }
43}
44
45impl From<String> for Prompt {
46 fn from(input: String) -> Self {
47 (PromptRole::User, input).into()
48 }
49}
50
51impl From<(PromptRole, &str)> for Prompt {
52 fn from((role, message): (PromptRole, &str)) -> Self {
53 (role, String::from(message)).into()
54 }
55}
56
57impl From<(PromptRole, String)> for Prompt {
58 fn from((role, message): (PromptRole, String)) -> Self {
59 Prompt {
60 messages: Vec::from([(role, message).into()]),
61 }
62 }
63}
64
65impl fmt::Display for Prompt {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 for PromptMessage(role, message) in &self.messages {
68 writeln!(f, "{}: {}", role, message)?;
69 }
70 Ok(())
71 }
72}
73
/// Converts the `prompt` field of an OpenAI `CreateCompletionRequest` into a
/// single-message [`Prompt`] (via `From<String>`, i.e. the user role).
///
/// # Errors
///
/// Returns `Err(())` for the pre-tokenized variants (`TokenArray` and
/// `TokenArrayArray`), which are not converted back to text here.
#[cfg(feature = "openai")]
impl TryFrom<openai::schemas::CreateCompletionRequest_Prompt> for Prompt {
    type Error = ();

    fn try_from(
        input: openai::schemas::CreateCompletionRequest_Prompt,
    ) -> Result<Self, Self::Error> {
        use openai::schemas::CreateCompletionRequest_Prompt::*;

        let text = match input {
            Text(prompt) => prompt,
            // NOTE(review): several text prompts are concatenated with no
            // separator into one message instead of being kept apart;
            // confirm this is the intended semantics.
            TextArray(prompts) => prompts.join(""),
            TokenArray(_) | TokenArrayArray(_) => return Err(()),
        };
        Ok(Self::from(text))
    }
}