use crate::content::{Annotations, EmbeddedResource, ImageContent};
use crate::handler::PromptError;
use crate::resource::ResourceContents;
use base64::engine::{general_purpose::STANDARD as BASE64_STANDARD, Engine};
use serde::{Deserialize, Serialize};
/// A named prompt with an optional description and an optional argument list.
///
/// Field names are serialized in camelCase; `description` and `arguments`
/// are omitted from the serialized output when they are `None`.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Prompt {
    /// Name of the prompt.
    pub name: String,
    /// Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional list of arguments the prompt declares.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<Vec<PromptArgument>>,
}
impl Prompt {
pub fn new<N, D>(
name: N,
description: Option<D>,
arguments: Option<Vec<PromptArgument>>,
) -> Self
where
N: Into<String>,
D: Into<String>,
{
Prompt {
name: name.into(),
description: description.map(Into::into),
arguments,
}
}
}
/// One argument declared by a [`Prompt`].
///
/// `description` and `required` are omitted from the serialized output
/// when they are `None`.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct PromptArgument {
    /// Name of the argument.
    pub name: String,
    /// Optional human-readable description of the argument.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Whether the argument is required; `None` when unspecified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<bool>,
}
/// The role a [`PromptMessage`] is attributed to.
///
/// Variants are serialized as lowercase strings.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PromptMessageRole {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
}
/// The payload of a [`PromptMessage`].
///
/// Internally tagged: the lowercased variant name is written to a `"type"`
/// field next to the variant's own field, e.g. `{"type": "text", "text": ...}`,
/// `{"type": "image", "image": {...}}` or `{"type": "resource", "resource": {...}}`.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum PromptMessageContent {
    /// Plain text content.
    Text { text: String },
    /// Image content (see `new_image` on [`PromptMessage`] for validation).
    Image { image: ImageContent },
    /// An embedded resource.
    Resource { resource: EmbeddedResource },
}
/// A single prompt message: who it is attributed to, plus its content.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct PromptMessage {
    /// The role the message is attributed to.
    pub role: PromptMessageRole,
    /// The message payload.
    pub content: PromptMessageContent,
}
impl PromptMessage {
    /// Creates a message whose content is plain text.
    ///
    /// # Arguments
    /// * `role` - The role the message is attributed to.
    /// * `text` - The message text; anything convertible into `String`.
    pub fn new_text<S: Into<String>>(role: PromptMessageRole, text: S) -> Self {
        Self {
            role,
            content: PromptMessageContent::Text { text: text.into() },
        }
    }

    /// Creates a message whose content is an image.
    ///
    /// The inputs are validated, then stored unchanged: `data` stays
    /// base64-encoded and `mime_type` keeps its original spelling.
    ///
    /// # Arguments
    /// * `role` - The role the message is attributed to.
    /// * `data` - Image bytes encoded as standard (padded) base64.
    /// * `mime_type` - An image MIME type such as `image/png`. The `image/`
    ///   prefix is matched ASCII case-insensitively, because MIME type names
    ///   are case-insensitive (RFC 2045 §5.1). A non-empty subtype is required.
    /// * `annotations` - Optional annotations attached to the image content.
    ///
    /// # Errors
    /// Returns `PromptError::InvalidParameters` if `data` does not decode as
    /// standard base64, or if `mime_type` is not of the form `image/<subtype>`.
    pub fn new_image<S: Into<String>>(
        role: PromptMessageRole,
        data: S,
        mime_type: S,
        annotations: Option<Annotations>,
    ) -> Result<Self, PromptError> {
        let data = data.into();
        let mime_type = mime_type.into();

        // Decode only to validate the payload; the decoded bytes are discarded.
        BASE64_STANDARD.decode(&data).map_err(|_| {
            PromptError::InvalidParameters("Image data must be valid base64".to_string())
        })?;

        if !Self::is_image_mime_type(&mime_type) {
            return Err(PromptError::InvalidParameters(
                "MIME type must be a valid image type (e.g. image/jpeg)".to_string(),
            ));
        }

        Ok(Self {
            role,
            content: PromptMessageContent::Image {
                image: ImageContent {
                    data,
                    mime_type,
                    annotations,
                },
            },
        })
    }

    /// Creates a message that embeds a text resource.
    ///
    /// # Arguments
    /// * `role` - The role the message is attributed to.
    /// * `uri` - URI of the resource.
    /// * `mime_type` - MIME type of the resource; always stored as `Some`.
    /// * `text` - Resource text; `None` is stored as an empty string.
    /// * `annotations` - Optional annotations for the embedded resource.
    pub fn new_resource(
        role: PromptMessageRole,
        uri: String,
        mime_type: String,
        text: Option<String>,
        annotations: Option<Annotations>,
    ) -> Self {
        let resource_contents = ResourceContents::TextResourceContents {
            uri,
            mime_type: Some(mime_type),
            text: text.unwrap_or_default(),
        };

        Self {
            role,
            content: PromptMessageContent::Resource {
                resource: EmbeddedResource {
                    resource: resource_contents,
                    annotations,
                },
            },
        }
    }

    /// Returns `true` if `mime_type` starts with `image/` (ASCII
    /// case-insensitive) and has at least one character after the slash.
    fn is_image_mime_type(mime_type: &str) -> bool {
        const PREFIX: &str = "image/";
        // `get` yields `None` if the string is shorter than the prefix or the
        // cut would split a multi-byte character; neither can be "image/".
        let has_image_prefix = matches!(
            mime_type.get(..PREFIX.len()),
            Some(head) if head.eq_ignore_ascii_case(PREFIX)
        );
        has_image_prefix && mime_type.len() > PREFIX.len()
    }
}
#[derive(Debug, Serialize, Deserialize)]
pub struct PromptTemplate {
pub id: String,
pub template: String,
pub arguments: Vec<PromptArgumentTemplate>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct PromptArgumentTemplate {
pub name: String,
pub description: Option<String>,
pub required: Option<bool>,
}