use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use crate::error::{Error, Result};
/// Capability flags advertised by a summarizer.
///
/// The derived [`Default`] yields the conservative baseline: no declared
/// input limit, no streaming and no structured output. Those are exactly the
/// default values of `Option` and `bool`, so no hand-written impl is needed.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SummarizerCapabilities {
    /// Maximum input size in tokens; `None` when unspecified.
    pub max_input_tokens: Option<usize>,
    /// Streaming capability flag; `false` by default.
    pub streaming: bool,
    /// Structured-output capability flag; `false` by default.
    pub structured_output: bool,
}
/// Style of summary requested from a summarizer.
///
/// `Compact` is the default variant. Each variant has a string tag produced
/// by `as_str` and parsed back by `from_persisted`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, serde::Serialize, serde::Deserialize)]
pub enum SummaryStyle {
/// Default style; tagged `"compact"`.
#[default]
Compact,
/// Tagged `"detailed"`.
Detailed,
/// Style identified by a caller-supplied name; tagged `"custom:<name>"`.
Custom(String),
}
impl SummaryStyle {
#[must_use]
pub fn as_str(&self) -> std::borrow::Cow<'static, str> {
match self {
SummaryStyle::Compact => std::borrow::Cow::Borrowed("compact"),
SummaryStyle::Detailed => std::borrow::Cow::Borrowed("detailed"),
SummaryStyle::Custom(n) => std::borrow::Cow::Owned(format!("custom:{n}")),
}
}
#[must_use]
pub fn from_persisted(s: &str) -> Self {
match s {
"compact" => SummaryStyle::Compact,
"detailed" => SummaryStyle::Detailed,
other => match other.strip_prefix("custom:") {
Some(n) if !n.is_empty() => SummaryStyle::Custom(n.to_string()),
_ => SummaryStyle::Compact,
},
}
}
}
/// Options passed to a summarizer for a single call.
///
/// Start from `SummarizeOpts::default()` and chain the `with_*` builders.
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct SummarizeOpts {
    /// Requested summary style.
    pub style: SummaryStyle,
    /// Optional target token count; `None` by default.
    pub target_tokens: Option<usize>,
}

impl SummarizeOpts {
    /// Returns these options with `style` replaced.
    #[must_use]
    pub fn with_style(self, style: SummaryStyle) -> Self {
        Self { style, ..self }
    }

    /// Returns these options with the target token count set to `n`.
    #[must_use]
    pub fn with_target_tokens(self, n: usize) -> Self {
        Self {
            target_tokens: Some(n),
            ..self
        }
    }
}
/// A summarization backend.
///
/// Implementors must be `Send + Sync + Debug + 'static`; the registry in this
/// file hands them out as `Box<dyn Summarizer>`.
#[async_trait]
pub trait Summarizer: Send + Sync + fmt::Debug + 'static {
/// Identifier of this summarizer.
fn id(&self) -> &str;
/// Capabilities of this summarizer.
///
/// Defaults to `SummarizerCapabilities::default()`.
fn capabilities(&self) -> SummarizerCapabilities {
SummarizerCapabilities::default()
}
/// Summarizes `inputs` using `opts`, returning the summary text or an error.
async fn summarize(&self, opts: &SummarizeOpts, inputs: &[&str]) -> Result<String>;
/// JSON value describing this summarizer.
///
/// Defaults to `serde_json::Value::Null`.
fn describe(&self) -> serde_json::Value {
serde_json::Value::Null
}
}
/// Pinned, boxed, `Send` future yielding `T`, bounded by lifetime `'a`.
type BoxFut<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Shared, type-erased factory: takes a JSON config value and returns a
/// `'static` future resolving to a boxed summarizer or an error.
type FactoryFn = Arc<
dyn Fn(serde_json::Value) -> BoxFut<'static, Result<Box<dyn Summarizer>>>
+ Send
+ Sync
+ 'static,
>;
/// Maps a summarizer family name to the async factory that builds it.
///
/// Factories are stored as `FactoryFn`, which is an `Arc`, so a cloned
/// registry shares the same factory closures as the original.
#[derive(Default, Clone)]
pub struct SummarizerRegistry {
// Family name -> factory registered under that name.
factories: HashMap<String, FactoryFn>,
}
impl fmt::Debug for SummarizerRegistry {
    /// Formats the registry as the list of registered family names.
    ///
    /// The factories themselves are opaque closures, so only the map keys are
    /// shown. The names are sorted because `HashMap` iterates in an arbitrary
    /// order, which would otherwise make the output differ between runs and
    /// between equal registries.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut families: Vec<&str> = self.factories.keys().map(String::as_str).collect();
        families.sort_unstable();
        f.debug_struct("SummarizerRegistry")
            .field("families", &families)
            .finish()
    }
}
impl SummarizerRegistry {
#[must_use]
pub fn empty() -> Self {
Self::default()
}
pub fn register<F, Fut>(&mut self, family: &'static str, factory: F)
where
F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Box<dyn Summarizer>>> + Send + 'static,
{
let f: FactoryFn = Arc::new(move |v| Box::pin(factory(v)));
self.factories.insert(family.to_string(), f);
}
pub async fn build(
&self,
family: &str,
config: serde_json::Value,
) -> Result<Box<dyn Summarizer>> {
let factory = self
.factories
.get(family)
.ok_or_else(|| Error::Config(format!("summarizer family {family:?} not registered")))?;
factory(config).await
}
}
// Unit tests for the summarizer trait, the registry and the style/option types.
#[cfg(test)]
mod tests {
use super::*;
// Minimal summarizer that echoes the style tag and the number of inputs.
#[derive(Debug)]
struct Fake;
#[async_trait]
impl Summarizer for Fake {
fn id(&self) -> &str {
"fake:1"
}
async fn summarize(&self, opts: &SummarizeOpts, inputs: &[&str]) -> Result<String> {
Ok(format!("style={} n={}", opts.style.as_str(), inputs.len()))
}
}
// Calls made through `Box<dyn Summarizer>` reach the concrete implementation.
#[tokio::test]
async fn boxed_summarizer_dispatches() {
let b: Box<dyn Summarizer> = Box::new(Fake);
let opts = SummarizeOpts::default().with_style(SummaryStyle::Detailed);
let s = b.summarize(&opts, &["a", "b"]).await.unwrap();
assert_eq!(s, "style=detailed n=2");
assert_eq!(b.id(), "fake:1");
}
// A registered factory can be built by its family name.
#[tokio::test]
async fn registry_round_trip() {
let mut r = SummarizerRegistry::empty();
r.register("fake", |_cfg| async {
Ok(Box::new(Fake) as Box<dyn Summarizer>)
});
let s = r.build("fake", serde_json::Value::Null).await.unwrap();
assert_eq!(s.id(), "fake:1");
}
// Building a family that was never registered yields `Error::Config`.
#[tokio::test]
async fn registry_unknown_family() {
let r = SummarizerRegistry::empty();
let err = r.build("nope", serde_json::Value::Null).await.unwrap_err();
assert!(matches!(err, Error::Config(_)));
}
// Default capabilities declare no input limit and no optional features.
#[test]
fn capabilities_default_is_conservative() {
let c = SummarizerCapabilities::default();
assert!(c.max_input_tokens.is_none());
assert!(!c.streaming);
assert!(!c.structured_output);
}
// `Compact` is the default style, and the built-in styles have fixed tags.
#[test]
fn summary_style_default_is_compact() {
assert_eq!(SummaryStyle::default(), SummaryStyle::Compact);
assert_eq!(SummaryStyle::Compact.as_str(), "compact");
assert_eq!(SummaryStyle::Detailed.as_str(), "detailed");
}
// A custom style survives `as_str` followed by `from_persisted`.
#[test]
fn summary_style_custom_round_trips_through_persistence() {
let s = SummaryStyle::Custom("daily-recap".into());
assert_eq!(s.as_str(), "custom:daily-recap");
assert_eq!(SummaryStyle::from_persisted("custom:daily-recap"), s);
}
// Known tags parse to their variants; an empty custom name and unknown text
// both fall back to `Compact`.
#[test]
fn summary_style_from_persisted_handles_known_tags() {
assert_eq!(
SummaryStyle::from_persisted("compact"),
SummaryStyle::Compact
);
assert_eq!(
SummaryStyle::from_persisted("detailed"),
SummaryStyle::Detailed
);
assert_eq!(
SummaryStyle::from_persisted("custom:"),
SummaryStyle::Compact
);
assert_eq!(
SummaryStyle::from_persisted("garbage"),
SummaryStyle::Compact
);
}
}