use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::ContextError;
use crate::provider::{ChatRequest, Message, ModelName, ToolChoice, ToolSpec};
use crate::tool::ToolRegistry;
/// Result alias used by this module's fallible operations, with
/// [`ContextError`] as the error type.
pub type ContextResult<T> = std::result::Result<T, ContextError>;
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ContextInput {
pub user_message: Option<String>,
pub session_id: Option<String>,
pub metadata: Value,
}
impl ContextInput {
    /// Creates an input with every field at its default value.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the input with `user_message` set to `message`.
    #[must_use]
    pub fn with_user_message(self, message: impl Into<String>) -> Self {
        Self {
            user_message: Some(message.into()),
            ..self
        }
    }

    /// Returns the input with `session_id` set to `session_id`.
    #[must_use]
    pub fn with_session_id(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }

    /// Returns the input with `metadata` replaced by the given JSON value.
    #[must_use]
    pub fn with_metadata(self, metadata: Value) -> Self {
        Self { metadata, ..self }
    }
}
/// Ordered list of [`Message`]s collected while building a context.
#[derive(Debug, Clone, Default)]
pub struct ContextOutput {
// Private: read through `messages()` / `into_messages()`, grown through `extend()`.
messages: Vec<Message>,
}
impl ContextOutput {
    /// Creates an output that holds no messages.
    #[must_use]
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
        }
    }

    /// Wraps an existing list of messages.
    #[must_use]
    pub fn from_messages(messages: Vec<Message>) -> Self {
        Self { messages }
    }

    /// Borrows the collected messages in order.
    #[must_use]
    pub fn messages(&self) -> &[Message] {
        self.messages.as_slice()
    }

    /// Consumes the output and hands back the owned messages.
    #[must_use]
    pub fn into_messages(self) -> Vec<Message> {
        let Self { messages } = self;
        messages
    }

    /// Appends `messages` after the ones already collected.
    pub fn extend(&mut self, messages: impl IntoIterator<Item = Message>) {
        Extend::extend(&mut self.messages, messages);
    }

    /// Converts the output into a [`ChatRequest`] for `model` with no tools.
    ///
    /// Every optional request setting is left unset, `tool_choice` takes its
    /// default, and `metadata` is `Value::Null`.
    #[must_use]
    pub fn into_request(self, model: ModelName) -> ChatRequest {
        // An empty tool slice yields the same empty `tools` vector.
        self.into_request_with_tools(model, &[])
    }

    /// Converts the output into a [`ChatRequest`] for `model`, attaching a
    /// copy of `tools`.
    ///
    /// Every optional request setting is left unset, `tool_choice` takes its
    /// default, and `metadata` is `Value::Null`.
    #[must_use]
    pub fn into_request_with_tools(self, model: ModelName, tools: &[ToolSpec]) -> ChatRequest {
        let Self { messages } = self;
        ChatRequest {
            model,
            messages,
            tools: tools.to_vec(),
            tool_choice: ToolChoice::default(),
            response_format: None,
            temperature: None,
            top_p: None,
            max_output_tokens: None,
            stop: Vec::new(),
            metadata: Value::Null,
        }
    }
}
/// A named, asynchronous source of [`Message`]s.
///
/// Implementors must be `Send + Sync`; [`ContextFactory`] stores them as
/// `Arc<dyn ContextAdapter>`.
#[async_trait]
pub trait ContextAdapter: Send + Sync {
/// Returns the adapter's name. [`ContextFactory::build`] records it in the
/// error it returns when this adapter fails.
fn name(&self) -> &str;
/// Produces this adapter's messages for the given `input`.
///
/// # Errors
///
/// Returns a [`ContextError`] when the adapter cannot produce its messages.
async fn produce(&self, input: &ContextInput) -> ContextResult<Vec<Message>>;
}
/// Adapter that returns the same fixed list of messages on every call.
///
/// Derives `Debug` and `Clone` so the type can be logged, embedded in other
/// `Debug` types, and duplicated; both fields already support these traits.
#[derive(Debug, Clone)]
pub struct StaticAdapter {
    /// Name reported through `ContextAdapter::name`.
    name: String,
    /// Messages cloned out on every `produce` call.
    messages: Vec<Message>,
}
impl StaticAdapter {
    /// Builds an adapter named `"system"` that yields one system message
    /// containing `text`.
    #[must_use]
    pub fn system(text: impl Into<String>) -> Self {
        Self::messages("system", vec![Message::system_text(text)])
    }

    /// Builds an adapter named `"user"` that yields one user message
    /// containing `text`.
    #[must_use]
    pub fn user(text: impl Into<String>) -> Self {
        Self::messages("user", vec![Message::user_text(text)])
    }

    /// Builds an adapter with a caller-chosen `name` that yields `messages`.
    #[must_use]
    pub fn messages(name: impl Into<String>, messages: Vec<Message>) -> Self {
        let name = name.into();
        Self { name, messages }
    }
}
#[async_trait]
impl ContextAdapter for StaticAdapter {
    /// Returns the name chosen when the adapter was constructed.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns a copy of the stored messages; the input is ignored.
    async fn produce(&self, _input: &ContextInput) -> ContextResult<Vec<Message>> {
        let copy = self.messages.to_vec();
        Ok(copy)
    }
}
/// Adapter that delegates message production to a stored handler of type `F`.
pub struct FunctionAdapter<F> {
    /// Name reported through `ContextAdapter::name`.
    name: String,
    /// Callable invoked on every `produce` call.
    handler: F,
}

/// Manual `Debug` so the adapter can be logged or embedded in `Debug` types.
///
/// A derive would require `F: Debug`, which closures do not implement, so
/// only the name is printed and the handler is elided.
impl<F> std::fmt::Debug for FunctionAdapter<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FunctionAdapter")
            .field("name", &self.name)
            .finish_non_exhaustive()
    }
}
impl<F, Fut> FunctionAdapter<F>
where
    F: Fn(ContextInput) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = ContextResult<Vec<Message>>> + Send + 'static,
{
    /// Creates an adapter called `name` that runs `handler` to produce its
    /// messages.
    #[must_use]
    pub fn new(name: impl Into<String>, handler: F) -> Self {
        let name = name.into();
        Self { name, handler }
    }
}
#[async_trait]
impl<F, Fut> ContextAdapter for FunctionAdapter<F>
where
    F: Fn(ContextInput) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = ContextResult<Vec<Message>>> + Send + 'static,
{
    /// Returns the name chosen when the adapter was constructed.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Calls the handler with an owned copy of `input` and awaits its result.
    async fn produce(&self, input: &ContextInput) -> ContextResult<Vec<Message>> {
        // The handler takes the input by value, so it receives a clone.
        let owned_input = input.clone();
        let pending = (self.handler)(owned_input);
        pending.await
    }
}
/// Ordered collection of [`ContextAdapter`]s that together build a context.
///
/// Cloning copies the `Arc` handles, so clones share the same adapter
/// instances.
#[derive(Clone, Default)]
pub struct ContextFactory {
    /// Registered adapters, in registration order.
    adapters: Vec<Arc<dyn ContextAdapter>>,
}

/// Manual `Debug` so the factory can be logged or embedded in `Debug` types.
///
/// `dyn ContextAdapter` is not `Debug`, so a derive is impossible; the
/// adapters are shown by name instead.
impl std::fmt::Debug for ContextFactory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let names: Vec<&str> = self.adapters.iter().map(|a| a.name()).collect();
        f.debug_struct("ContextFactory")
            .field("adapters", &names)
            .finish()
    }
}
impl ContextFactory {
    /// Creates a factory with no adapters registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            adapters: Vec::new(),
        }
    }

    /// Registers `adapter` after all previously registered adapters.
    pub fn register<A>(&mut self, adapter: A)
    where
        A: ContextAdapter + 'static,
    {
        self.register_arc(Arc::new(adapter));
    }

    /// Registers an already shared adapter after all previously registered
    /// adapters.
    pub fn register_arc(&mut self, adapter: Arc<dyn ContextAdapter>) {
        self.adapters.push(adapter);
    }

    /// Number of registered adapters.
    #[must_use]
    pub fn len(&self) -> usize {
        self.adapters.len()
    }

    /// `true` when no adapter has been registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.adapters.is_empty()
    }

    /// Iterates over the adapter names in registration order.
    pub fn adapter_names(&self) -> impl Iterator<Item = &str> {
        self.adapters.iter().map(|registered| registered.name())
    }

    /// Runs every adapter in registration order and concatenates the
    /// messages they produce.
    ///
    /// # Errors
    ///
    /// Stops at the first adapter that fails and returns
    /// [`ContextError::AdapterFailed`] carrying that adapter's name and the
    /// string form of its error. Later adapters are not run.
    pub async fn build(&self, input: &ContextInput) -> ContextResult<ContextOutput> {
        let mut assembled = ContextOutput::new();
        for adapter in &self.adapters {
            let produced = adapter.produce(input).await;
            match produced {
                Ok(batch) => assembled.extend(batch),
                Err(cause) => {
                    return Err(ContextError::AdapterFailed {
                        adapter: adapter.name().to_owned(),
                        message: cause.to_string(),
                    });
                }
            }
        }
        Ok(assembled)
    }

    /// Builds the context and converts it into a [`ChatRequest`] for `model`.
    ///
    /// When `tool_registry` is `Some`, the registry's tool specs are attached
    /// to the request; otherwise the request carries no tools.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`ContextFactory::build`].
    pub async fn build_request(
        &self,
        input: &ContextInput,
        model: ModelName,
        tool_registry: Option<&ToolRegistry>,
    ) -> ContextResult<ChatRequest> {
        let output = self.build(input).await?;
        Ok(match tool_registry {
            Some(registry) => {
                let specs = registry.specs();
                output.into_request_with_tools(model, &specs)
            }
            None => output.into_request(model),
        })
    }
}
#[cfg(test)]
// `unwrap` is acceptable here: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Each builder method sets exactly the field it is named after.
    #[test]
    fn context_input_should_support_builder_pattern() {
        let input = ContextInput::new()
            .with_user_message("Hello")
            .with_session_id("session_123")
            .with_metadata(json!({"key": "value"}));
        assert_eq!(input.user_message, Some("Hello".to_owned()));
        assert_eq!(input.session_id, Some("session_123".to_owned()));
        assert_eq!(input.metadata, json!({"key": "value"}));
    }

    #[test]
    fn context_output_should_be_empty_when_new() {
        let output = ContextOutput::new();
        assert!(output.messages().is_empty());
    }

    #[test]
    fn context_output_should_extend_messages() {
        let mut output = ContextOutput::new();
        output.extend(vec![
            Message::system_text("System"),
            Message::user_text("User"),
        ]);
        assert_eq!(output.messages().len(), 2);
    }

    #[test]
    fn context_output_should_convert_to_request() {
        let output = ContextOutput::from_messages(vec![
            Message::system_text("System"),
            Message::user_text("User"),
        ]);
        let request = output.into_request(ModelName::new("gpt-4"));
        assert_eq!(request.model.as_str(), "gpt-4");
        assert_eq!(request.messages.len(), 2);
        assert!(request.tools.is_empty());
    }

    #[test]
    fn context_output_should_convert_to_request_with_tools() {
        let output = ContextOutput::from_messages(vec![Message::user_text("Hello")]);
        let tools = vec![ToolSpec::new("echo", "Echo tool", json!({}))];
        let request = output.into_request_with_tools(ModelName::new("gpt-4"), &tools);
        assert_eq!(request.tools.len(), 1);
        assert_eq!(request.tools[0].name, "echo");
    }

    #[test]
    fn context_factory_should_be_empty_when_new() {
        let factory = ContextFactory::new();
        assert!(factory.is_empty());
        assert_eq!(factory.len(), 0);
    }

    #[test]
    fn context_factory_should_register_adapters() {
        let mut factory = ContextFactory::new();
        factory.register(StaticAdapter::system("System prompt"));
        factory.register(StaticAdapter::user("User message"));
        assert_eq!(factory.len(), 2);
    }

    #[test]
    fn context_factory_should_list_adapter_names() {
        let mut factory = ContextFactory::new();
        factory.register(StaticAdapter::system("System"));
        factory.register(StaticAdapter::user("User"));
        let names: Vec<&str> = factory.adapter_names().collect();
        assert_eq!(names, vec!["system", "user"]);
    }

    #[tokio::test]
    async fn context_factory_should_build_output_in_order() {
        let mut factory = ContextFactory::new();
        factory.register(StaticAdapter::system("First"));
        factory.register(StaticAdapter::user("Second"));
        let input = ContextInput::new();
        let output = factory.build(&input).await.unwrap();
        assert_eq!(output.messages().len(), 2);
        // Check the ordering the test name promises, not just the count.
        assert!(matches!(output.messages()[0], Message::System { .. }));
        assert!(matches!(output.messages()[1], Message::User { .. }));
    }

    #[tokio::test]
    async fn context_factory_should_build_request_with_tools() {
        let mut factory = ContextFactory::new();
        factory.register(StaticAdapter::system("You are helpful."));
        let registry = ToolRegistry::new();
        registry.register(crate::tool::FunctionTool::new(
            "echo",
            "Echo",
            json!({}),
            |_: Value| async { Ok(Value::Null) },
        ));
        let input = ContextInput::new().with_user_message("Hello");
        let request = factory
            .build_request(&input, ModelName::new("gpt-4"), Some(&registry))
            .await
            .unwrap();
        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.tools.len(), 1);
    }

    #[tokio::test]
    async fn static_adapter_should_produce_system_message() {
        let adapter = StaticAdapter::system("You are a helpful assistant.");
        let input = ContextInput::new();
        let messages = adapter.produce(&input).await.unwrap();
        assert_eq!(messages.len(), 1);
        assert!(matches!(messages[0], Message::System { .. }));
    }

    #[tokio::test]
    async fn static_adapter_should_produce_user_message() {
        let adapter = StaticAdapter::user("Hello");
        let input = ContextInput::new();
        let messages = adapter.produce(&input).await.unwrap();
        assert_eq!(messages.len(), 1);
        assert!(matches!(messages[0], Message::User { .. }));
    }

    #[tokio::test]
    async fn function_adapter_should_invoke_handler() {
        let adapter = FunctionAdapter::new("custom", |input: ContextInput| async move {
            let msg = input.user_message.unwrap_or_default();
            Ok(vec![Message::user_text(format!("Echo: {msg}"))])
        });
        let input = ContextInput::new().with_user_message("test");
        let messages = adapter.produce(&input).await.unwrap();
        assert_eq!(messages.len(), 1);
    }
}