use alloc::borrow::ToOwned;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::String;
use crate::error::AccumulateError;
/// Kind of a streamed content block; it determines which delta event the block accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockKind {
/// A text block; its content arrives via `text_delta` fragments.
Text,
/// A tool-use block; its content arrives via `input_json_delta` fragments.
ToolUse,
}
/// One content block of a streamed message, as accumulated so far.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContentBlock {
/// Kind given in `content_block_start`; deltas of the other kind are rejected.
pub kind: BlockKind,
/// Concatenation of every delta fragment received for this block, in arrival order.
pub text: String,
/// Set once `content_block_stop` has been received for this block.
pub stopped: bool,
}
/// Accumulates a stream of message events (`message_start`, `content_block_*`,
/// `message_delta`, `message_stop`) into per-index content blocks and a stop reason.
#[derive(Debug, Default)]
pub struct AnthropicAccumulator {
/// Set by `message_start`; content-block events are rejected until it is true.
started: bool,
/// Set by `message_stop`; once true, content-block events are rejected.
stopped: bool,
/// Content blocks keyed by stream index; a `BTreeMap`, so iteration is index-ordered.
blocks: BTreeMap<usize, ContentBlock>,
/// Most recent stop reason received via `message_delta`.
stop_reason: Option<String>,
}
impl AnthropicAccumulator {
    /// Creates an empty accumulator that has not yet seen `message_start`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a `message_start` event; content-block events are rejected until this is called.
    ///
    /// Repeated calls have no further effect. This does not undo a previous
    /// [`Self::message_stop`].
    pub fn message_start(&mut self) {
        self.started = true;
    }

    /// Opens content block `index` of the given `kind`, with empty text.
    ///
    /// An index whose block has already been stopped may be started again. The new block
    /// replaces the old one, and the old block's accumulated text is discarded.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulateError::UnexpectedEvent`] if called before `message_start`,
    /// after `message_stop`, or while block `index` is still open.
    pub fn content_block_start(
        &mut self,
        index: usize,
        kind: BlockKind,
    ) -> Result<(), AccumulateError> {
        self.ensure_active("content_block_start")?;
        if matches!(self.blocks.get(&index), Some(block) if !block.stopped) {
            return Err(AccumulateError::UnexpectedEvent {
                got: format!("content_block_start for block {index} that is already open"),
            });
        }
        // NOTE(review): a previously *stopped* block at this index is silently replaced
        // here; confirm whether re-using a finished index should be rejected instead.
        self.blocks.insert(
            index,
            ContentBlock {
                kind,
                text: String::new(),
                stopped: false,
            },
        );
        Ok(())
    }

    /// Appends `fragment` to the text of the open `Text` block at `index`.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulateError::UnexpectedEvent`] in any of these cases:
    /// - called before `message_start` or after `message_stop`;
    /// - block `index` was never started;
    /// - block `index` has already been stopped.
    ///
    /// Returns [`AccumulateError::BlockKindMismatch`] if the block is not a `Text` block.
    pub fn text_delta(&mut self, index: usize, fragment: &str) -> Result<(), AccumulateError> {
        self.delta(index, fragment, BlockKind::Text)
    }

    /// Appends `fragment` verbatim to the text of the open `ToolUse` block at `index`.
    ///
    /// Fragments are concatenated as-is. No JSON parsing or validation happens here.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulateError::UnexpectedEvent`] in any of these cases:
    /// - called before `message_start` or after `message_stop`;
    /// - block `index` was never started;
    /// - block `index` has already been stopped.
    ///
    /// Returns [`AccumulateError::BlockKindMismatch`] if the block is not a `ToolUse` block.
    pub fn input_json_delta(
        &mut self,
        index: usize,
        fragment: &str,
    ) -> Result<(), AccumulateError> {
        self.delta(index, fragment, BlockKind::ToolUse)
    }

    /// Marks the content block at `index` as stopped; later deltas to it are rejected.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulateError::UnexpectedEvent`] in any of these cases:
    /// - called before `message_start` or after `message_stop`;
    /// - block `index` was never started;
    /// - block `index` was already stopped.
    pub fn content_block_stop(&mut self, index: usize) -> Result<(), AccumulateError> {
        self.ensure_active("content_block_stop")?;
        self.open_block_mut(index, "content_block_stop")?.stopped = true;
        Ok(())
    }

    /// Records the stop reason carried by a `message_delta` event.
    ///
    /// `Some` overwrites any previously recorded reason; `None` leaves it untouched.
    /// No event-ordering checks are performed.
    pub fn message_delta(&mut self, stop_reason: Option<&str>) {
        if let Some(reason) = stop_reason {
            self.stop_reason = Some(reason.to_owned());
        }
    }

    /// Records a `message_stop` event; subsequent content-block events are rejected.
    ///
    /// No event-ordering checks are performed, so this is accepted even before
    /// `message_start`.
    pub fn message_stop(&mut self) {
        self.stopped = true;
    }

    /// Returns the most recent stop reason recorded by [`Self::message_delta`], if any.
    #[must_use]
    pub fn stop_reason(&self) -> Option<&str> {
        self.stop_reason.as_deref()
    }

    /// Returns the content block at `index`, whether it is still open or already stopped.
    #[must_use]
    pub fn block(&self, index: usize) -> Option<&ContentBlock> {
        self.blocks.get(&index)
    }

    /// Iterates over all content blocks in ascending index order.
    pub fn blocks(&self) -> impl Iterator<Item = (usize, &ContentBlock)> {
        self.blocks.iter().map(|(&i, b)| (i, b))
    }

    /// Fails with `UnexpectedEvent` unless `message_start` has been seen and
    /// `message_stop` has not. `what` names the offending event in the error message.
    fn ensure_active(&self, what: &str) -> Result<(), AccumulateError> {
        if !self.started {
            return Err(AccumulateError::UnexpectedEvent {
                got: format!("{what} before message_start"),
            });
        }
        if self.stopped {
            return Err(AccumulateError::UnexpectedEvent {
                got: format!("{what} after message_stop"),
            });
        }
        Ok(())
    }

    /// Shared body of the delta events: appends `fragment` to the open block at `index`,
    /// whose kind must equal `expected`.
    fn delta(
        &mut self,
        index: usize,
        fragment: &str,
        expected: BlockKind,
    ) -> Result<(), AccumulateError> {
        self.ensure_active("content_block_delta")?;
        // Deltas are only valid between content_block_start and content_block_stop.
        let block = self.open_block_mut(index, "content_block_delta")?;
        if block.kind != expected {
            return Err(AccumulateError::BlockKindMismatch {
                index,
                expected: kind_name(expected),
                actual: kind_name(block.kind),
            });
        }
        block.text.push_str(fragment);
        Ok(())
    }

    /// Looks up a started block that has not been stopped yet. `what` names the
    /// offending event in the error message.
    fn open_block_mut(
        &mut self,
        index: usize,
        what: &str,
    ) -> Result<&mut ContentBlock, AccumulateError> {
        let block = self.block_mut(index)?;
        if block.stopped {
            return Err(AccumulateError::UnexpectedEvent {
                got: format!("{what} for content block {index} that is already stopped"),
            });
        }
        Ok(block)
    }

    /// Looks up a started block (open or stopped), failing if `index` was never started.
    fn block_mut(&mut self, index: usize) -> Result<&mut ContentBlock, AccumulateError> {
        self.blocks
            .get_mut(&index)
            .ok_or_else(|| AccumulateError::UnexpectedEvent {
                got: format!("delta/stop for content block {index} that was never started"),
            })
    }
}
fn kind_name(kind: BlockKind) -> &'static str {
match kind {
BlockKind::Text => "text",
BlockKind::ToolUse => "tool_use",
}
}