mod assistant_message;
mod developer_message;
mod file_content;
mod rich_content;
mod simple_content;
mod system_message;
mod tool_message;
mod user_message;
pub use assistant_message::*;
pub use developer_message::*;
pub use file_content::*;
pub use rich_content::*;
pub use simple_content::*;
pub use system_message::*;
pub use tool_message::*;
pub use user_message::*;
#[cfg(test)]
mod assistant_message_tests;
use crate::functions;
use functions::expression::{ExpressionError, FromStarlarkValue};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use starlark::values::dict::DictRef as StarlarkDictRef;
use starlark::values::{UnpackValue, Value as StarlarkValue};
/// Helpers that normalise a prompt (a list of [`Message`]s) and derive a
/// stable identifier from it.
pub mod prompt {
    use super::Message;

    /// Returns `true` when `b` may be folded into the preceding message `a`.
    ///
    /// Only developer, system and user messages chain, and only with a message
    /// of the same role. Two such messages chain unless both carry a name and
    /// the names differ.
    fn is_chain(a: &Message, b: &Message) -> bool {
        match (a, b) {
            (Message::Developer(a), Message::Developer(b)) => {
                !a.has_name() || !b.has_name() || a.name == b.name
            }
            (Message::System(a), Message::System(b)) => {
                !a.has_name() || !b.has_name() || a.name == b.name
            }
            (Message::User(a), Message::User(b)) => {
                !a.has_name() || !b.has_name() || a.name == b.name
            }
            _ => false,
        }
    }

    /// Appends the contents of `other` onto `target`.
    ///
    /// # Panics
    ///
    /// Panics unless `target` and `other` are both developer, both system, or
    /// both user messages. Callers are expected to check `is_chain` first.
    fn push(target: &mut Message, other: &Message) {
        match (target, other) {
            (Message::Developer(t), Message::Developer(o)) => t.push(o),
            (Message::System(t), Message::System(o)) => t.push(o),
            (Message::User(t), Message::User(o)) => t.push(o),
            _ => unreachable!("prompt::push called on messages that cannot chain"),
        }
    }

    /// Normalises `messages` in place.
    ///
    /// Every message is first prepared with [`Message::prepare`]. Then each run
    /// of adjacent messages that chain (see `is_chain`) is merged into the
    /// first message of the run. The prepare/merge pass is repeated until no
    /// adjacent pair chains any more.
    pub fn prepare(messages: &mut Vec<Message>) {
        // A loop instead of recursion: every pass that finds a chain merges at
        // least one pair, because nothing before the first chained pair is
        // modified. So this terminates, but the number of passes is bounded
        // only by the message count, and recursion would add one stack frame
        // per pass.
        loop {
            messages.iter_mut().for_each(Message::prepare);
            let has_chain =
                messages.windows(2).any(|w| is_chain(&w[0], &w[1]));
            if !has_chain {
                return;
            }
            let mut merged: Vec<Message> = Vec::with_capacity(messages.len());
            for msg in messages.drain(..) {
                if let Some(last) = merged.last_mut() {
                    if is_chain(last, &msg) {
                        push(last, &msg);
                        continue;
                    }
                }
                merged.push(msg);
            }
            *messages = merged;
        }
    }

    /// Returns a stable identifier for `messages`.
    ///
    /// The messages are serialised to JSON and hashed with
    /// `twox_hash::XxHash3_128` (seed 0). The 128-bit hash is base62-encoded
    /// and left-padded with `'0'` to 22 characters, which is the longest
    /// base62 rendering of a `u128`, so every id has the same width.
    ///
    /// # Panics
    ///
    /// Panics if `messages` cannot be serialised to JSON.
    pub fn id(messages: &[Message]) -> String {
        let json = serde_json::to_string(messages)
            .expect("prompt messages should always serialize to JSON");
        let mut hasher = twox_hash::XxHash3_128::with_seed(0);
        hasher.write(json.as_bytes());
        format!("{:0>22}", base62::encode(hasher.finish_128()))
    }
}
// A single chat message.
//
// It is serialized as an internally tagged object whose `role` field selects
// the variant. `rename_all = "lowercase"` maps each variant name to its role
// string: "developer", "system", "user", "assistant" or "tool".
//
// Plain `//` comments are used on purpose: schemars copies `///` doc
// comments into the generated JSON schema as descriptions.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(tag = "role", rename_all = "lowercase")]
#[schemars(rename = "agent.completions.message.Message")]
pub enum Message {
    #[schemars(title = "Developer")]
    Developer(DeveloperMessage),
    #[schemars(title = "System")]
    System(SystemMessage),
    #[schemars(title = "User")]
    User(UserMessage),
    #[schemars(title = "Assistant")]
    Assistant(AssistantMessage),
    #[schemars(title = "Tool")]
    Tool(ToolMessage),
}
impl Message {
pub fn prepare(&mut self) {
match self {
Message::Developer(msg) => msg.prepare(),
Message::System(msg) => msg.prepare(),
Message::User(msg) => msg.prepare(),
Message::Assistant(msg) => msg.prepare(),
Message::Tool(msg) => msg.prepare(),
}
}
}
impl FromStarlarkValue for Message {
    /// Builds a [`Message`] from a Starlark dict.
    ///
    /// The dict's `"role"` entry is read first, and conversion is delegated to
    /// the matching role-specific type.
    ///
    /// # Errors
    ///
    /// Returns `ExpressionError::StarlarkConversionError` in these cases:
    ///
    /// - `value` is not a dict.
    /// - The dict has no `"role"` key.
    /// - The `"role"` value does not unpack as a string.
    /// - The role is not one of the known roles.
    ///
    /// Errors from the role-specific conversion are returned unchanged.
    fn from_starlark_value(
        value: &StarlarkValue,
    ) -> Result<Self, ExpressionError> {
        fn conversion_error(message: impl Into<String>) -> ExpressionError {
            ExpressionError::StarlarkConversionError(message.into())
        }

        let dict = match StarlarkDictRef::from_value(*value) {
            Some(dict) => dict,
            None => return Err(conversion_error("Message: expected dict")),
        };

        // Find the value stored under the string key "role". Keys that do
        // not unpack as strings are skipped.
        let role_value = dict
            .iter()
            .find_map(|(key, val)| {
                matches!(
                    <&str as UnpackValue>::unpack_value(key),
                    Ok(Some("role"))
                )
                .then_some(val)
            })
            .ok_or_else(|| conversion_error("Message: missing role"))?;

        let role = <&str as UnpackValue>::unpack_value(role_value)
            .map_err(|e| conversion_error(e.to_string()))?
            .ok_or_else(|| conversion_error("Message: expected string role"))?;

        match role {
            "developer" => {
                DeveloperMessage::from_starlark_value(value).map(Self::Developer)
            }
            "system" => {
                SystemMessage::from_starlark_value(value).map(Self::System)
            }
            "user" => UserMessage::from_starlark_value(value).map(Self::User),
            "assistant" => {
                AssistantMessage::from_starlark_value(value).map(Self::Assistant)
            }
            "tool" => ToolMessage::from_starlark_value(value).map(Self::Tool),
            unknown => Err(conversion_error(format!(
                "Message: unknown role: {}",
                unknown
            ))),
        }
    }
}
// The uncompiled, expression-bearing form of a chat message (see
// `MessageExpression::compile`).
//
// It is serialized as an internally tagged object whose `role` field selects
// the variant. `rename_all = "lowercase"` maps each variant name to its role
// string: "developer", "system", "user", "assistant" or "tool".
//
// Plain `//` comments are used on purpose: schemars copies `///` doc
// comments into the generated JSON schema as descriptions.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(tag = "role", rename_all = "lowercase")]
#[schemars(rename = "agent.completions.message.MessageExpression")]
pub enum MessageExpression {
    #[schemars(title = "Developer")]
    Developer(DeveloperMessageExpression),
    #[schemars(title = "System")]
    System(SystemMessageExpression),
    #[schemars(title = "User")]
    User(UserMessageExpression),
    #[schemars(title = "Assistant")]
    Assistant(AssistantMessageExpression),
    #[schemars(title = "Tool")]
    Tool(ToolMessageExpression),
}
impl MessageExpression {
    /// Compiles this expression against `params` into the concrete
    /// [`Message`] of the same role.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the role-specific `compile`.
    pub fn compile(
        self,
        params: &functions::expression::Params,
    ) -> Result<Message, functions::expression::ExpressionError> {
        let compiled = match self {
            Self::Developer(expr) => Message::Developer(expr.compile(params)?),
            Self::System(expr) => Message::System(expr.compile(params)?),
            Self::User(expr) => Message::User(expr.compile(params)?),
            Self::Assistant(expr) => Message::Assistant(expr.compile(params)?),
            Self::Tool(expr) => Message::Tool(expr.compile(params)?),
        };
        Ok(compiled)
    }
}
impl FromStarlarkValue for MessageExpression {
    /// Builds a [`MessageExpression`] from a Starlark dict.
    ///
    /// The dict's `"role"` entry is read first, and conversion is delegated to
    /// the matching role-specific expression type.
    ///
    /// # Errors
    ///
    /// Returns `ExpressionError::StarlarkConversionError` in these cases:
    ///
    /// - `value` is not a dict.
    /// - The dict has no `"role"` key.
    /// - The `"role"` value does not unpack as a string.
    /// - The role is not one of the known roles.
    ///
    /// Errors from the role-specific conversion are returned unchanged.
    fn from_starlark_value(
        value: &StarlarkValue,
    ) -> Result<Self, ExpressionError> {
        fn conversion_error(message: impl Into<String>) -> ExpressionError {
            ExpressionError::StarlarkConversionError(message.into())
        }

        let dict = match StarlarkDictRef::from_value(*value) {
            Some(dict) => dict,
            None => {
                return Err(conversion_error(
                    "MessageExpression: expected dict",
                ))
            }
        };

        // Find the value stored under the string key "role". Keys that do
        // not unpack as strings are skipped.
        let role_value = dict
            .iter()
            .find_map(|(key, val)| {
                matches!(
                    <&str as UnpackValue>::unpack_value(key),
                    Ok(Some("role"))
                )
                .then_some(val)
            })
            .ok_or_else(|| conversion_error("MessageExpression: missing role"))?;

        let role = <&str as UnpackValue>::unpack_value(role_value)
            .map_err(|e| conversion_error(e.to_string()))?
            .ok_or_else(|| {
                conversion_error("MessageExpression: expected string role")
            })?;

        match role {
            "developer" => DeveloperMessageExpression::from_starlark_value(value)
                .map(Self::Developer),
            "system" => {
                SystemMessageExpression::from_starlark_value(value).map(Self::System)
            }
            "user" => {
                UserMessageExpression::from_starlark_value(value).map(Self::User)
            }
            "assistant" => AssistantMessageExpression::from_starlark_value(value)
                .map(Self::Assistant),
            "tool" => {
                ToolMessageExpression::from_starlark_value(value).map(Self::Tool)
            }
            unknown => Err(conversion_error(format!(
                "MessageExpression: unknown role: {}",
                unknown
            ))),
        }
    }
}
// NOTE(review): judging by its name, this macro presumably implements
// `FromSpecial` for `MessageExpression` so that every special is rejected as
// unsupported. Confirm against the macro definition in
// `crate::functions::expression`.
crate::functions::expression::impl_from_special_unsupported!(MessageExpression,);
impl crate::functions::expression::FromSpecial
for Vec<crate::functions::expression::WithExpression<MessageExpression>>
{
fn from_special(
_special: &crate::functions::expression::Special,
_params: &crate::functions::expression::Params,
) -> Result<Self, crate::functions::expression::ExpressionError> {
Err(crate::functions::expression::ExpressionError::UnsupportedSpecial)
}
}