use std::pin::Pin;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use tokio_stream::Stream;
use crate::error::AnthropicError;
use crate::DEFAULT_MODEL;
#[derive(Clone, Serialize, Default, Debug, Builder, PartialEq)]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "AnthropicError"))]
pub struct CompleteRequest {
pub prompt: String,
#[builder(default = "DEFAULT_MODEL.to_string()")]
pub model: String,
pub max_tokens_to_sample: usize,
pub stop_sequences: Option<Vec<String>>,
#[builder(default = "false")]
pub stream: bool,
}
/// Result of a text-completion request; also the item type of
/// `CompleteResponseStream`.
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CompleteResponse {
/// Generated text.
pub completion: String,
/// Why generation stopped, if the payload reports it.
pub stop_reason: Option<StopReason>,
}
/// Author of a conversation turn.
///
/// Encoded on the wire as the strings `"user"` and `"assistant"`.
/// `User` is the `Default`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
    /// The human side of the conversation (the default).
    #[default]
    #[serde(rename = "user")]
    User,
    /// The model side of the conversation.
    #[serde(rename = "assistant")]
    Assistant,
}
/// A single piece of message content.
///
/// On the wire, blocks are tagged by a `"type"` field:
///
/// - `Text`  <-> `{"type": "text", "text": "..."}`
/// - `Image` <-> `{"type": "image", "source": {"type": <source>, "media_type": <media_type>, "data": <data>}}`
///
/// For `Image`, `source` holds the image source kind (e.g. `"base64"`),
/// `media_type` the MIME type (e.g. `"image/png"`), and `data` the encoded
/// image. These Rust fields are flat for API compatibility. On the wire they
/// are nested under `source`, which is why (de)serialization goes through a
/// private mirror type.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(from = "ContentBlockWire", into = "ContentBlockWire")]
pub enum ContentBlock {
    /// Plain text content.
    Text { text: String },
    /// Inline image content.
    Image {
        source: String,
        media_type: String,
        data: String,
    },
}

/// Serde-facing mirror of `ContentBlock` that matches the wire layout.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
enum ContentBlockWire {
    Text { text: String },
    Image { source: ImageSourceWire },
}

/// Wire form of an image block's nested `source` object.
#[derive(Serialize, Deserialize)]
struct ImageSourceWire {
    #[serde(rename = "type")]
    kind: String,
    media_type: String,
    data: String,
}

impl From<ContentBlockWire> for ContentBlock {
    fn from(wire: ContentBlockWire) -> Self {
        match wire {
            ContentBlockWire::Text { text } => Self::Text { text },
            ContentBlockWire::Image { source } => Self::Image {
                source: source.kind,
                media_type: source.media_type,
                data: source.data,
            },
        }
    }
}

impl From<ContentBlock> for ContentBlockWire {
    fn from(block: ContentBlock) -> Self {
        match block {
            ContentBlock::Text { text } => Self::Text { text },
            ContentBlock::Image {
                source,
                media_type,
                data,
            } => Self::Image {
                source: ImageSourceWire {
                    kind: source,
                    media_type,
                    data,
                },
            },
        }
    }
}
/// One conversation turn: who authored it and the content it carries.
///
/// Field names are already the wire names, so no renaming is applied.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Builder)]
pub struct Message {
    /// Author of this turn.
    pub role: Role,
    /// Content blocks of this turn, in order.
    pub content: Vec<ContentBlock>,
}
/// Parameters for a Messages API request.
///
/// Usually assembled with the generated `MessagesRequestBuilder`:
/// - mutable builder pattern;
/// - `Into` setters, with `Option` fields set by their inner value;
/// - build errors reported as `AnthropicError`.
///
/// Any field not set on the builder takes its value from
/// `MessagesRequest::default()`, except `model`, which falls back to
/// `DEFAULT_MODEL`. Empty or absent optional fields are left out of the
/// serialized body.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Builder)]
#[builder(
    pattern = "mutable",
    setter(into, strip_option),
    default,
    derive(Debug),
    build_fn(error = "AnthropicError")
)]
pub struct MessagesRequest {
    /// Conversation turns to send.
    pub messages: Vec<Message>,
    /// System prompt; omitted from the body when empty.
    #[serde(skip_serializing_if = "String::is_empty")]
    pub system: String,
    /// Model identifier; `DEFAULT_MODEL` when not set on the builder.
    #[builder(default = "DEFAULT_MODEL.to_string()")]
    pub model: String,
    /// Upper bound on generated tokens. NOTE(review): this is 0 when not set
    /// on the builder (derived `Default`); confirm callers always set it.
    pub max_tokens: usize,
    /// Stop sequences; omitted from the body when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    /// Whether a streamed response is requested. When unset it is `false`,
    /// taken from the derived `Default`.
    pub stream: bool,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Nucleus-sampling threshold; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Top-k sampling cutoff; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
}
/// Reason generation stopped, encoded as `"end_turn"`, `"max_tokens"` or
/// `"stop_sequence"`.
///
/// NOTE(review): there is no catch-all variant, so any other value in a
/// payload makes deserialization of the enclosing response/event fail.
/// Confirm this matches the API version targeted.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
/// Encoded as `"end_turn"`.
EndTurn,
/// Encoded as `"max_tokens"`.
MaxTokens,
/// Encoded as `"stop_sequence"`.
StopSequence,
}
/// Token counts reported with a `MessagesResponse`.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Usage {
/// Tokens counted on the input side.
pub input_tokens: usize,
/// Tokens counted on the output side.
pub output_tokens: usize,
}
/// Non-streaming result of a Messages API request.
#[derive(Debug, Deserialize, Clone, PartialEq, Eq, Serialize)]
pub struct MessagesResponse {
/// Identifier of the response message.
pub id: String,
/// Value of the payload's `"type"` key (raw identifier because `type` is a
/// Rust keyword). Presumably always `"message"`; not validated here.
pub r#type: String,
/// Author role of the returned message.
pub role: Role,
/// Generated content blocks.
pub content: Vec<ContentBlock>,
/// Model that produced the response.
pub model: String,
/// Why generation stopped, if reported.
pub stop_reason: Option<StopReason>,
/// The stop sequence that was matched, if any.
pub stop_sequence: Option<String>,
/// Token accounting for the request.
pub usage: Usage,
}
/// Pinned, boxed, `Send` stream of text-completion chunks; each item is a
/// decoded `CompleteResponse` or an `AnthropicError`.
pub type CompleteResponseStream = Pin<Box<dyn Stream<Item = Result<CompleteResponse, AnthropicError>> + Send>>;
/// Pinned, boxed, `Send` stream of Messages API streaming events; each item is
/// a decoded `MessagesStreamEvent` or an `AnthropicError`.
pub type MessagesResponseStream = Pin<Box<dyn Stream<Item = Result<MessagesStreamEvent, AnthropicError>> + Send>>;
/// Incremental update to a content block during streaming, tagged by
/// `"type"`.
///
/// NOTE(review): only `"text_delta"` is modelled, so any other delta type
/// fails to deserialize.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ContentBlockDelta {
/// `{"type": "text_delta", "text": ...}`: text to append to the block.
TextDelta { text: String },
}
/// Token usage attached to a `message_delta` streaming event.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct MessageDeltaUsage {
    /// Output token count reported by the event.
    // Public so callers can read the value; while private, the deserialized
    // count was unreachable outside this module.
    pub output_tokens: usize,
}
/// Top-level message changes carried by a `message_delta` streaming event.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct MessageDelta {
/// Why generation stopped, once known.
pub stop_reason: Option<StopReason>,
/// The stop sequence that was matched, if any.
pub stop_sequence: Option<String>,
}
/// One decoded event from a streamed Messages API response, selected by the
/// payload's `"type"` key (snake_case variant name, e.g. `"message_start"`).
///
/// NOTE(review): only these six event types are modelled. Payloads with any
/// other `"type"` (for example keep-alive or error events, if the server sends
/// them) will not deserialize into this enum. Confirm the stream decoder
/// filters them out before parsing.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum MessagesStreamEvent {
/// `"message_start"`. Only `role` and `content` of the embedded message are
/// kept, because `Message` declares only those fields and serde ignores the rest.
MessageStart { message: Message },
/// `"content_block_start"`: a new content block begins at `index`.
ContentBlockStart { index: usize, content_block: ContentBlock },
/// `"content_block_delta"`: incremental update for the block at `index`.
ContentBlockDelta { index: usize, delta: ContentBlockDelta },
/// `"content_block_stop"`: the block at `index` is complete.
ContentBlockStop { index: usize },
/// `"message_delta"`: top-level message changes plus usage.
MessageDelta { delta: MessageDelta, usage: MessageDeltaUsage },
/// `"message_stop"`: end of the message.
MessageStop,
}
/// Error object with a `"type"` and a `"message"` key, e.g. one reported in
/// place of normal stream data.
#[derive(Debug, Deserialize, Clone, PartialEq, Eq, Serialize)]
pub struct StreamError {
/// Error category, (de)serialized under the `"type"` key.
#[serde(rename = "type")]
pub error_type: String,
/// Human-readable error description.
pub message: String,
}
impl std::fmt::Display for StreamError {
    /// Renders the error as `Error (<error_type>): <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            error_type,
            message,
        } = self;
        write!(f, "Error ({}): {}", error_type, message)
    }
}