use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use brainwires_core::ContentBlock;
/// Per-request context handed to a [`PersonaProvider`] when it builds persona content.
///
/// Every field is optional. Fields missing from serialized input deserialize to
/// their defaults (`None` / empty `Vec`), and unset fields are omitted when
/// serializing, so the serialized form only carries what was actually provided.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PersonaContext {
/// Optional user identifier.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub user_id: Option<String>,
/// Optional session identifier.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
/// Optional locale tag (this module's tests use `"en-US"`-style values).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub locale: Option<String>,
/// Tool identifiers for the request; empty when none were supplied.
/// NOTE(review): presumably the names of tools available to the agent — confirm with callers.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<String>,
/// Optional text of the most recent user message.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub last_user_message: Option<String>,
}
impl PersonaContext {
    /// Creates an empty context; equivalent to [`PersonaContext::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`user_id`](Self::user_id), replacing any previous value.
    #[must_use]
    pub fn with_user_id(mut self, id: impl Into<String>) -> Self {
        self.user_id = Some(id.into());
        self
    }

    /// Sets [`session_id`](Self::session_id), replacing any previous value.
    #[must_use]
    pub fn with_session_id(mut self, id: impl Into<String>) -> Self {
        self.session_id = Some(id.into());
        self
    }

    /// Sets [`locale`](Self::locale), replacing any previous value.
    #[must_use]
    pub fn with_locale(mut self, locale: impl Into<String>) -> Self {
        self.locale = Some(locale.into());
        self
    }

    /// Sets [`tools`](Self::tools), replacing any previously set list.
    ///
    /// Accepts any iterable of string-like items, e.g. `["search", "calc"]`
    /// or a `Vec<String>`.
    #[must_use]
    pub fn with_tools<I, S>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.tools = tools.into_iter().map(Into::into).collect();
        self
    }

    /// Sets [`last_user_message`](Self::last_user_message), replacing any previous value.
    #[must_use]
    pub fn with_last_user_message(mut self, message: impl Into<String>) -> Self {
        self.last_user_message = Some(message.into());
        self
    }
}
/// Produces persona content blocks for a given [`PersonaContext`].
///
/// Implementors must be `Send + Sync` so they can be shared behind
/// `Arc<dyn PersonaProvider>` (as [`CompositePersonaProvider`] does).
#[async_trait]
pub trait PersonaProvider: Send + Sync {
/// Builds the persona as a list of content blocks.
///
/// Returning an empty `Vec` means "contribute nothing" (this is what
/// [`StaticPersonaProvider`] does for empty text).
///
/// # Errors
///
/// Returns an error if the persona cannot be built; the composite provider
/// propagates the first such error to its caller.
async fn build(&self, ctx: &PersonaContext) -> Result<Vec<ContentBlock>>;
}
/// Provider that always returns the same fixed persona text and ignores the context.
///
/// An empty text yields no blocks at all. The derived `Default` is the
/// empty-text provider.
#[derive(Debug, Clone, Default)]
pub struct StaticPersonaProvider {
    text: String,
}

impl StaticPersonaProvider {
    /// Creates a provider that emits `text` as a single text block (or nothing
    /// when `text` is empty).
    pub fn new(text: impl Into<String>) -> Self {
        Self { text: text.into() }
    }
}
#[async_trait]
impl PersonaProvider for StaticPersonaProvider {
    /// Yields the configured text as one `Text` block, or no blocks when the
    /// text is empty. The context is not consulted.
    async fn build(&self, _ctx: &PersonaContext) -> Result<Vec<ContentBlock>> {
        let blocks = if self.text.is_empty() {
            Vec::new()
        } else {
            let text = self.text.clone();
            vec![ContentBlock::Text { text }]
        };
        Ok(blocks)
    }
}
/// Provider that concatenates the output of several child providers, in order.
///
/// Children are held as `Arc<dyn PersonaProvider>`, so one provider instance
/// can be shared by several composites. The derived `Default` is an empty
/// composite with no children.
#[derive(Clone, Default)]
pub struct CompositePersonaProvider {
    providers: Vec<Arc<dyn PersonaProvider>>,
}

impl CompositePersonaProvider {
    /// Creates a composite that runs `providers` in the given order.
    pub fn new(providers: Vec<Arc<dyn PersonaProvider>>) -> Self {
        Self { providers }
    }

    /// Appends `p` after the existing providers and returns the updated composite.
    ///
    /// This consumes `self` (builder style), so discarding the return value
    /// would drop the whole composite; hence `#[must_use]`.
    #[must_use]
    pub fn push(mut self, p: Arc<dyn PersonaProvider>) -> Self {
        self.providers.push(p);
        self
    }
}
#[async_trait]
impl PersonaProvider for CompositePersonaProvider {
    /// Runs every child provider sequentially and concatenates their blocks,
    /// so the output order matches the order the providers were added in.
    ///
    /// # Errors
    ///
    /// Stops at the first failing provider and returns its error, wrapped with
    /// context naming that provider's zero-based index. The original error
    /// remains in the chain, so `downcast_ref` on it still works.
    async fn build(&self, ctx: &PersonaContext) -> Result<Vec<ContentBlock>> {
        let mut out = Vec::new();
        for (index, provider) in self.providers.iter().enumerate() {
            let blocks = provider
                .build(ctx)
                .await
                .map_err(|err| err.context(format!("persona provider #{index} failed")))?;
            out.extend(blocks);
        }
        Ok(out)
    }
}
/// Flattens persona blocks into a single system-prompt string.
///
/// Each block is handled as follows:
///
/// - Text blocks are included verbatim. Empty text blocks are skipped so they
///   cannot introduce a doubled or dangling separator.
/// - Image blocks are replaced by the placeholder
///   `[persona: image attachment omitted]`.
/// - Tool-use and tool-result blocks are dropped.
///
/// Rendered segments are separated by a blank line (`"\n\n"`). An empty slice,
/// or one with nothing renderable, yields an empty string.
pub fn blocks_to_system_text(blocks: &[ContentBlock]) -> String {
    let mut out = String::new();
    for block in blocks {
        let segment = match block {
            ContentBlock::Text { text } => text.as_str(),
            ContentBlock::Image { .. } => "[persona: image attachment omitted]",
            // Tool traffic has no meaning inside a system prompt.
            ContentBlock::ToolUse { .. } | ContentBlock::ToolResult { .. } => continue,
        };
        if segment.is_empty() {
            continue;
        }
        if !out.is_empty() {
            out.push_str("\n\n");
        }
        out.push_str(segment);
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider that always fails; used to exercise error propagation.
    struct FailingProvider;

    #[async_trait]
    impl PersonaProvider for FailingProvider {
        async fn build(&self, _ctx: &PersonaContext) -> Result<Vec<ContentBlock>> {
            anyhow::bail!("boom")
        }
    }

    /// Shorthand for a shared static provider.
    fn provider(text: &str) -> Arc<dyn PersonaProvider> {
        Arc::new(StaticPersonaProvider::new(text))
    }

    #[tokio::test]
    async fn static_provider_returns_single_block() {
        let p = StaticPersonaProvider::new("you are helpful");
        let blocks = p.build(&PersonaContext::new()).await.unwrap();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::Text { text } => assert_eq!(text, "you are helpful"),
            _ => panic!("expected text block"),
        }
    }

    #[tokio::test]
    async fn static_empty_returns_empty() {
        let p = StaticPersonaProvider::new("");
        let blocks = p.build(&PersonaContext::new()).await.unwrap();
        assert!(blocks.is_empty());
    }

    #[tokio::test]
    async fn composite_chains_in_order() {
        let composite = CompositePersonaProvider::new(vec![provider("base"), provider("addendum")]);
        let blocks = composite.build(&PersonaContext::new()).await.unwrap();
        assert_eq!(blocks.len(), 2);
        match (&blocks[0], &blocks[1]) {
            (ContentBlock::Text { text: t1 }, ContentBlock::Text { text: t2 }) => {
                assert_eq!(t1, "base");
                assert_eq!(t2, "addendum");
            }
            _ => panic!("expected two text blocks in order"),
        }
    }

    #[tokio::test]
    async fn composite_push_extends() {
        let composite = CompositePersonaProvider::new(vec![provider("first")]).push(provider("second"));
        let text = blocks_to_system_text(&composite.build(&PersonaContext::new()).await.unwrap());
        assert_eq!(text, "first\n\nsecond");
    }

    #[tokio::test]
    async fn composite_empty_yields_nothing() {
        let composite = CompositePersonaProvider::new(Vec::new());
        let blocks = composite.build(&PersonaContext::new()).await.unwrap();
        assert!(blocks.is_empty());
        assert_eq!(blocks_to_system_text(&blocks), "");
    }

    #[tokio::test]
    async fn composite_skips_empty_static_provider() {
        let composite =
            CompositePersonaProvider::new(vec![provider("a"), provider(""), provider("b")]);
        let text = blocks_to_system_text(&composite.build(&PersonaContext::new()).await.unwrap());
        assert_eq!(text, "a\n\nb");
    }

    #[tokio::test]
    async fn composite_propagates_error() {
        let failing: Arc<dyn PersonaProvider> = Arc::new(FailingProvider);
        let composite = CompositePersonaProvider::new(vec![provider("fine"), failing]);
        let err = composite
            .build(&PersonaContext::new())
            .await
            .err()
            .expect("a failing child provider must surface an error");
        // `{:#}` renders the full error chain, so the root cause is always visible.
        assert!(format!("{err:#}").contains("boom"));
    }

    #[test]
    fn blocks_to_system_text_joins_with_blank_line() {
        let blocks = vec![
            ContentBlock::Text { text: "one".into() },
            ContentBlock::Text { text: "two".into() },
        ];
        assert_eq!(blocks_to_system_text(&blocks), "one\n\ntwo");
        assert_eq!(blocks_to_system_text(&[]), "");
    }

    #[test]
    fn context_builders() {
        let c = PersonaContext::new()
            .with_user_id("u1")
            .with_session_id("s1")
            .with_locale("en-US");
        assert_eq!(c.user_id.as_deref(), Some("u1"));
        assert_eq!(c.session_id.as_deref(), Some("s1"));
        assert_eq!(c.locale.as_deref(), Some("en-US"));
    }

    #[test]
    fn context_new_is_empty() {
        let c = PersonaContext::new();
        assert!(c.user_id.is_none());
        assert!(c.session_id.is_none());
        assert!(c.locale.is_none());
        assert!(c.tools.is_empty());
        assert!(c.last_user_message.is_none());
    }
}