use crate::json_parser::deduplication::get_overlap_thresholds;
use crate::json_parser::deduplication::rolling_hash::RollingHashWindow;
use std::io::Write as IoWrite;
/// Per-stream bookkeeping for incremental (delta-based) message rendering.
///
/// Accumulates text / thinking / tool-input deltas per `(ContentType, key)`,
/// tracks the message and content-block lifecycle, and keeps the state used to
/// suppress duplicated or replayed output. Construct with
/// `StreamingSession::new()`: the derived `Default` alone leaves
/// `max_delta_history` at 0.
#[derive(Debug, Default, Clone)]
pub struct StreamingSession {
/// Message lifecycle: deltas set `Streaming`, `on_message_stop` sets
/// `Finalized`, and a message-start reset returns it to `Idle`.
pub(super) state: StreamingState,
/// Content types that have received at least one delta (the delta handlers
/// only ever insert `true`).
pub(super) streamed_types: HashMap<ContentType, bool>,
/// Content block currently open, if any.
pub(super) current_block: ContentBlockState,
/// Concatenated delta text per `(content type, key)`.
pub(super) accumulated: HashMap<(ContentType, String), String>,
/// Keys of `accumulated` in the order they were first recorded.
pub(super) key_order: Vec<(ContentType, String)>,
/// Recent delta lengths per key, capped at `max_delta_history` entries.
pub(super) delta_sizes: HashMap<(ContentType, String), Vec<usize>>,
/// Maximum entries kept per key in `delta_sizes`
/// (`DEFAULT_MAX_DELTA_HISTORY` when built via `new()`).
pub(super) max_delta_history: usize,
/// Id of the message currently being streamed, if known.
pub(super) current_message_id: Option<String>,
/// Message ids already marked displayed (e.g. by `on_message_stop`).
pub(super) displayed_final_messages: HashSet<String>,
/// Keys whose output has begun; preserved across a mid-stream MessageStart
/// so the prefix is not printed again.
pub(super) output_started_for_key: HashSet<(ContentType, String)>,
/// When true, diagnostics are written to stderr.
pub(super) verbose_warnings: bool,
/// Number of deltas successfully re-derived from a snapshot-style delta.
pub(super) snapshot_repairs_count: usize,
/// Number of deltas longer than `snapshot_threshold()`.
pub(super) large_delta_count: usize,
/// Number of MessageStart events received while already streaming.
pub(super) protocol_violations: usize,
/// Hash from `compute_content_hash()` captured at message stop.
pub(super) final_content_hash: Option<u64>,
/// Accumulated content as it was when last passed to `mark_rendered`.
pub(super) last_rendered: HashMap<(ContentType, String), String>,
/// `(content type, key, content hash)` triples already rendered.
pub(super) rendered_content_hashes: HashSet<(ContentType, String, u64)>,
/// Raw (pre-snapshot-repair) last text delta per key.
pub(super) last_delta: HashMap<(ContentType, String), String>,
/// Per key: `(run length, hash)` of the current run of identical deltas.
pub(super) consecutive_duplicates: HashMap<(ContentType, String), (usize, u64)>,
/// Delta deduplicator (type defined elsewhere); cleared on every
/// message-start reset.
pub(super) deduplicator: DeltaDeduplicator,
/// Message ids recorded via `mark_message_pre_rendered`.
pub(super) pre_rendered_message_ids: HashSet<String>,
/// Content hashes recorded via `mark_assistant_content_rendered`.
pub(super) rendered_assistant_content_hashes: HashSet<u64>,
/// Optional tool name per content-block index, set via `set_tool_name`.
pub(super) tool_names: HashMap<u64, Option<String>>,
}
impl StreamingSession {
    /// Builds an empty session.
    ///
    /// The delta-size history is capped at `DEFAULT_MAX_DELTA_HISTORY`, and
    /// verbose stderr warnings start out disabled.
    #[must_use]
    pub fn new() -> Self {
        let mut session = Self::default();
        session.max_delta_history = DEFAULT_MAX_DELTA_HISTORY;
        session.verbose_warnings = false;
        session
    }

    /// Builder-style toggle for verbose stderr diagnostics.
    #[must_use]
    pub const fn with_verbose_warnings(self, enabled: bool) -> Self {
        let mut this = self;
        this.verbose_warnings = enabled;
        this
    }
}
impl StreamingSession {
    /// Clears the per-message streaming state shared by both message-start paths.
    ///
    /// Returns the lifecycle to `Idle` and leaves any open content block. It
    /// drops the accumulated text, key order, delta-size history,
    /// last-rendered snapshots, deduplicator state and tool names.
    ///
    /// `output_started_for_key`, `last_delta`, `rendered_content_hashes` and
    /// `consecutive_duplicates` are NOT touched here; each caller decides
    /// whether they survive.
    fn reset_streaming_state_base(&mut self) {
        self.state = StreamingState::Idle;
        self.streamed_types.clear();
        self.current_block = ContentBlockState::NotInBlock;
        self.accumulated.clear();
        self.key_order.clear();
        self.delta_sizes.clear();
        self.last_rendered.clear();
        self.deduplicator.clear();
        self.tool_names.clear();
    }

    /// Handles a `MessageStart` that arrives while a message is still streaming.
    ///
    /// Counts a protocol violation and optionally warns on stderr. It then
    /// resets streaming state while carrying over the prefix/dedup bookkeeping,
    /// so a key whose output already began is not re-prefixed.
    fn on_mid_stream_restart(&mut self) {
        self.protocol_violations = self.protocol_violations.saturating_add(1);
        if self.verbose_warnings {
            let _ = writeln!(
                std::io::stderr(),
                "Warning: Received MessageStart while state is Streaming. \
                 This indicates a non-standard agent protocol (e.g., GLM sending \
                 repeated MessageStart events). Preserving output_started_for_key \
                 to prevent prefix spam. File: state_management.rs, Line: {}",
                line!()
            );
        }
        // Take/restore keeps these four fields intact even if the base reset is
        // later extended to clear them.
        let preserved_output_started = std::mem::take(&mut self.output_started_for_key);
        let preserved_last_delta = std::mem::take(&mut self.last_delta);
        let preserved_rendered_hashes = std::mem::take(&mut self.rendered_content_hashes);
        let preserved_consecutive_duplicates = std::mem::take(&mut self.consecutive_duplicates);
        self.reset_streaming_state_base();
        self.output_started_for_key = preserved_output_started;
        self.last_delta = preserved_last_delta;
        self.rendered_content_hashes = preserved_rendered_hashes;
        self.consecutive_duplicates = preserved_consecutive_duplicates;
    }

    /// Handles a regular `MessageStart`: a full reset, including the
    /// prefix and duplicate-tracking bookkeeping.
    fn on_normal_message_start(&mut self) {
        self.reset_streaming_state_base();
        self.output_started_for_key.clear();
        self.last_delta.clear();
        self.rendered_content_hashes.clear();
        self.consecutive_duplicates.clear();
    }

    /// Entry point for a `MessageStart` event.
    ///
    /// While already `Streaming` this is treated as a mid-stream restart that
    /// preserves output bookkeeping; otherwise all state is reset.
    pub fn on_message_start(&mut self) {
        if self.state == StreamingState::Streaming {
            self.on_mid_stream_restart();
        } else {
            self.on_normal_message_start();
        }
    }

    /// Sets (or clears, with `None`) the id of the message being streamed.
    pub fn set_current_message_id(&mut self, message_id: Option<String>) {
        self.current_message_id = message_id;
    }

    /// Id of the message currently being streamed, if one was set.
    #[must_use]
    pub fn get_current_message_id(&self) -> Option<&str> {
        self.current_message_id.as_deref()
    }

    /// Whether `message_id` was already marked displayed.
    #[must_use]
    pub fn is_duplicate_final_message(&self, message_id: &str) -> bool {
        self.displayed_final_messages.contains(message_id)
    }

    /// Records `message_id` as displayed.
    pub fn mark_message_displayed(&mut self, message_id: &str) {
        self.displayed_final_messages.insert(message_id.to_string());
    }

    /// Records `message_id` as pre-rendered.
    pub fn mark_message_pre_rendered(&mut self, message_id: &str) {
        self.pre_rendered_message_ids.insert(message_id.to_string());
    }

    /// Whether `message_id` was recorded via `mark_message_pre_rendered`.
    #[must_use]
    pub fn is_message_pre_rendered(&self, message_id: &str) -> bool {
        self.pre_rendered_message_ids.contains(message_id)
    }

    /// Whether `content_hash` was recorded via `mark_assistant_content_rendered`.
    #[must_use]
    pub fn is_assistant_content_rendered(&self, content_hash: u64) -> bool {
        self.rendered_assistant_content_hashes.contains(&content_hash)
    }

    /// Records `content_hash` as rendered assistant content.
    pub fn mark_assistant_content_rendered(&mut self, content_hash: u64) {
        self.rendered_assistant_content_hashes.insert(content_hash);
    }

    /// Opens content block `index`, implicitly closing any block still open.
    ///
    /// The new block starts with `started_output: false`.
    pub fn on_content_block_start(&mut self, index: u64) {
        let index_str = index.to_string();
        self.ensure_content_block_finalized();
        self.current_block = ContentBlockState::InBlock {
            index: index_str,
            started_output: false,
        };
    }

    /// Leaves the current content block, if any.
    ///
    /// Returns `true` only when a block was open and had produced output.
    fn ensure_content_block_finalized(&mut self) -> bool {
        if let ContentBlockState::InBlock { started_output, .. } = &self.current_block {
            let had_output = *started_output;
            self.current_block = ContentBlockState::NotInBlock;
            had_output
        } else {
            false
        }
    }

    /// Debug-build sanity check that the session is in one of `expected`.
    ///
    /// # Panics
    ///
    /// In builds with `debug_assertions`, panics when the current state is not
    /// listed. In release builds this is a no-op.
    fn assert_lifecycle_state(&self, expected: &[StreamingState]) {
        #[cfg(debug_assertions)]
        assert!(
            expected.contains(&self.state),
            "Invalid lifecycle state: expected {:?}, got {:?}. \
             This indicates a bug in the parser's event handling.",
            expected,
            self.state
        );
        #[cfg(not(debug_assertions))]
        let _ = expected;
    }

    /// Handles a `MessageStop` event.
    ///
    /// Closes any open block, marks the session `Finalized`, and captures the
    /// final content hash. The current message id, if any, is recorded as
    /// displayed.
    ///
    /// Returns whether an open block had produced output.
    pub fn on_message_stop(&mut self) -> bool {
        let was_in_block = self.ensure_content_block_finalized();
        self.state = StreamingState::Finalized;
        self.final_content_hash = self.compute_content_hash();
        // `current_message_id` and `displayed_final_messages` are disjoint
        // fields, so the id can be borrowed directly. This avoids cloning the
        // whole Option first and then allocating a second time on insert.
        if let Some(message_id) = self.current_message_id.as_deref() {
            self.displayed_final_messages.insert(message_id.to_owned());
        }
        was_in_block
    }

    /// Forgets everything tracked for `(content_type, key)`.
    ///
    /// Covers accumulated text, ordering, prefix state, delta history,
    /// render snapshots, duplicate tracking, and rendered hashes.
    pub fn clear_key(&mut self, content_type: ContentType, key: &str) {
        let content_key = (content_type, key.to_string());
        self.accumulated.remove(&content_key);
        self.key_order.retain(|k| k != &content_key);
        self.output_started_for_key.remove(&content_key);
        self.delta_sizes.remove(&content_key);
        self.last_rendered.remove(&content_key);
        self.last_delta.remove(&content_key);
        self.consecutive_duplicates.remove(&content_key);
        self.rendered_content_hashes
            .retain(|(ct, k, _hash)| !(*ct == content_type && k == key));
    }

    /// Whether any content type has received a delta since the last reset.
    #[must_use]
    pub fn has_any_streamed_content(&self) -> bool {
        !self.streamed_types.is_empty()
    }
}
/// Advances one key's run of identical deltas.
///
/// When `delta_hash` differs from `prev_hash`, a new run starts: the count
/// becomes 1, the hash is remembered, and `false` is returned. Otherwise the
/// run length grows (saturating), and the result reports whether it has
/// reached `threshold`.
fn update_consecutive_dup_entry(
    count: &mut usize,
    prev_hash: &mut u64,
    delta_hash: u64,
    threshold: usize,
) -> bool {
    if *prev_hash != delta_hash {
        *count = 1;
        *prev_hash = delta_hash;
        return false;
    }
    *count = count.saturating_add(1);
    *count >= threshold
}
/// Reports a dropped consecutive-duplicate delta on stderr.
///
/// Does nothing unless `verbose` is set. Write failures are ignored.
fn warn_if_verbose_consecutive_dup(
    verbose: bool,
    count: usize,
    threshold: usize,
    key_str: &str,
    delta: &str,
) {
    if !verbose {
        return;
    }
    let mut stderr = std::io::stderr();
    let _ = writeln!(
        stderr,
        "Warning: Dropping consecutive duplicate delta (count={count}, threshold={threshold}). \
         This appears to be a resend glitch. Key: '{key_str}', Delta: {delta:?}",
    );
}
/// Detects a delta that exactly repeats the previous one for `content_key`
/// while nothing is accumulated for that key yet.
///
/// Returns `true` when both hold:
/// - the stored last delta equals `delta`;
/// - the accumulated buffer for the key is absent or empty.
///
/// This is the replay pattern seen after a message-start reset.
fn is_exact_duplicate_after_message_start(
    last_delta: &HashMap<(ContentType, String), String>,
    accumulated: &HashMap<(ContentType, String), String>,
    content_key: &(ContentType, String),
    delta: &str,
) -> bool {
    let repeats_previous = last_delta
        .get(content_key)
        .is_some_and(|previous| previous == delta);
    repeats_previous
        && accumulated
            .get(content_key)
            .is_none_or(|buffer| buffer.is_empty())
}
/// Warns, in verbose mode only, when a key's recent delta history holds many
/// oversized deltas.
///
/// `sizes` is the recent delta-length history for `key`. An entry counts as
/// large when it exceeds `snapshot_threshold()`. The warning is emitted once
/// at least `DEFAULT_PATTERN_DETECTION_MIN_DELTAS` large entries are present.
/// Write failures to stderr are ignored.
fn warn_large_delta_pattern(verbose: bool, sizes: &[usize], key: &str) {
    // The count only feeds a verbose diagnostic, and this runs on every text
    // delta, so skip the scan entirely when warnings are off.
    if !verbose {
        return;
    }
    // Loop-invariant: read the threshold once rather than once per entry.
    let threshold = snapshot_threshold();
    let large_count = sizes.iter().filter(|&&size| size > threshold).count();
    // `large_count <= sizes.len()`, so this also guarantees the history holds
    // at least the minimum number of deltas.
    if large_count >= DEFAULT_PATTERN_DETECTION_MIN_DELTAS {
        let _ = writeln!(std::io::stderr(), "Warning: Detected pattern of {large_count} large deltas for key '{key}'. This strongly suggests a snapshot-as-delta bug.");
    }
}
impl StreamingSession {
    /// Text delta addressed by numeric block index.
    ///
    /// Forwards to [`Self::on_text_delta_key`] with `index` as a decimal
    /// string key.
    pub fn on_text_delta(&mut self, index: u64, delta: &str) -> bool {
        self.on_text_delta_key(&index.to_string(), delta)
    }

    /// Returns `true` when `delta` extends a run of identical deltas for
    /// `content_key` to at least the configured
    /// `consecutive_duplicate_threshold`.
    ///
    /// The first delta ever seen for a key seeds the run and is never
    /// reported. When a delta is reported, a verbose warning may be printed.
    fn check_consecutive_duplicate(
        &mut self,
        content_key: &(ContentType, String),
        delta: &str,
        key_str: &str,
    ) -> bool {
        let delta_hash = RollingHashWindow::compute_hash(delta);
        let threshold = get_overlap_thresholds().consecutive_duplicate_threshold;
        let Some((count, prev_hash)) = self.consecutive_duplicates.get_mut(content_key) else {
            self.consecutive_duplicates
                .insert(content_key.clone(), (1, delta_hash));
            return false;
        };
        let exceeded = update_consecutive_dup_entry(count, prev_hash, delta_hash, threshold);
        if exceeded {
            warn_if_verbose_consecutive_dup(self.verbose_warnings, *count, threshold, key_str, delta);
        }
        exceeded
    }

    /// Counts deltas longer than `snapshot_threshold()`, with an optional
    /// verbose warning.
    fn track_large_delta_warning(&mut self, delta_size: usize, key: &str) {
        if delta_size <= snapshot_threshold() {
            return;
        }
        self.large_delta_count = self.large_delta_count.saturating_add(1);
        if self.verbose_warnings {
            let _ = writeln!(std::io::stderr(), "Warning: Large delta ({delta_size} chars) for key '{key}'. This may indicate unusual streaming behavior.");
        }
    }

    /// Appends `delta_size` to the key's history, keeping at most
    /// `max_delta_history` of the newest entries.
    fn track_delta_size(&mut self, content_key: &(ContentType, String), delta_size: usize) {
        let cap = self.max_delta_history;
        let sizes = self.delta_sizes.entry(content_key.clone()).or_default();
        sizes.push(delta_size);
        // Drop the whole excess from the front (not just one element) so the
        // bound holds even if `max_delta_history` was lowered after history
        // had already accumulated.
        if sizes.len() > cap {
            let excess = sizes.len() - cap;
            sizes.drain(..excess);
        }
    }

    /// Records size metrics for a raw incoming delta.
    fn track_delta_metrics(&mut self, content_key: &(ContentType, String), delta: &str, key: &str) {
        self.track_large_delta_warning(delta.len(), key);
        self.track_delta_size(content_key, delta.len());
    }

    /// Verbose-only notice that snapshot extraction failed and the raw delta
    /// is used instead.
    fn warn_snapshot_fallback(&self, e: &SnapshotDeltaError) {
        if self.verbose_warnings {
            let _ = writeln!(std::io::stderr(), "Warning: Snapshot extraction failed: {e}. Using original delta.");
        }
    }

    /// Returns the text to append for `delta`.
    ///
    /// If `is_likely_snapshot` flags the delta, the new suffix is extracted via
    /// `get_delta_from_snapshot`, and a successful repair is counted. On
    /// extraction failure, or for ordinary deltas, the raw delta is returned.
    fn resolve_snapshot_or_delta(&mut self, delta: &str, key: &str) -> String {
        if !self.is_likely_snapshot(delta, key) {
            return delta.to_string();
        }
        match self.get_delta_from_snapshot(delta, key) {
            Ok(extracted) => {
                self.snapshot_repairs_count = self.snapshot_repairs_count.saturating_add(1);
                extracted.to_string()
            }
            Err(e) => {
                self.warn_snapshot_fallback(&e);
                delta.to_string()
            }
        }
    }

    /// Applies an accepted text delta.
    ///
    /// Sets streaming state, appends `actual_delta`, and remembers the raw
    /// `delta` as the key's last delta. Returns `true` when this is the first
    /// output for the key, i.e. the caller should emit a prefix.
    fn commit_text_delta(
        &mut self,
        content_key: (ContentType, String),
        actual_delta: String,
        delta: &str,
        key: &str,
    ) -> bool {
        self.streamed_types.insert(ContentType::Text, true);
        self.state = StreamingState::Streaming;
        self.current_block = ContentBlockState::InBlock {
            index: key.to_string(),
            started_output: true,
        };
        // `HashSet::insert` returns true only when the key was newly added.
        let is_first = self.output_started_for_key.insert(content_key.clone());
        // `key_order` must mirror the keys of `accumulated`. Tie the push to a
        // fresh accumulated entry, not to `is_first`: after a mid-stream
        // restart, `output_started_for_key` survives while `accumulated` and
        // `key_order` are cleared, and keying on `is_first` would leave the
        // re-created entry missing from `key_order`.
        if let Some(buffer) = self.accumulated.get_mut(&content_key) {
            buffer.push_str(&actual_delta);
        } else {
            self.accumulated.insert(content_key.clone(), actual_delta);
            self.key_order.push(content_key.clone());
        }
        self.last_delta.insert(content_key, delta.to_string());
        is_first
    }

    /// Processes one text delta for `key`.
    ///
    /// The delta is dropped (returning `false`) when it:
    /// - exactly replays the previous delta while nothing is accumulated;
    /// - extends a run of consecutive duplicates past the threshold;
    /// - becomes empty after snapshot repair.
    ///
    /// Otherwise the delta is committed, and the return value says whether it
    /// is the first output for the key.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the session is neither `Idle` nor `Streaming`.
    pub fn on_text_delta_key(&mut self, key: &str, delta: &str) -> bool {
        self.assert_lifecycle_state(&[StreamingState::Idle, StreamingState::Streaming]);
        let content_key = (ContentType::Text, key.to_string());
        // Metrics are recorded for every raw delta, including ones dropped below.
        self.track_delta_metrics(&content_key, delta, key);
        if is_exact_duplicate_after_message_start(&self.last_delta, &self.accumulated, &content_key, delta) {
            return false;
        }
        if self.check_consecutive_duplicate(&content_key, delta, key) {
            return false;
        }
        let actual_delta = self.resolve_snapshot_or_delta(delta, key);
        warn_large_delta_pattern(
            self.verbose_warnings,
            self.delta_sizes.get(&content_key).map_or(&[], |v| v),
            key,
        );
        if actual_delta.is_empty() {
            return false;
        }
        self.commit_text_delta(content_key, actual_delta, delta, key)
    }
}
impl StreamingSession {
    /// Thinking delta addressed by numeric block index.
    ///
    /// Delegates to `on_thinking_delta_key` using the decimal form of `index`.
    pub fn on_thinking_delta(&mut self, index: u64, delta: &str) -> bool {
        let key = index.to_string();
        self.on_thinking_delta_key(&key, delta)
    }

    /// Records a thinking delta under `key`.
    ///
    /// Flags `ContentType::Thinking` as streamed, enters `Streaming`, and
    /// returns whatever `merge_delta` reports for the merge.
    pub fn on_thinking_delta_key(&mut self, key: &str, delta: &str) -> bool {
        self.streamed_types.insert(ContentType::Thinking, true);
        self.state = StreamingState::Streaming;
        let Self {
            accumulated,
            key_order,
            output_started_for_key,
            ..
        } = self;
        merge_delta(
            accumulated,
            key_order,
            output_started_for_key,
            ContentType::Thinking,
            key,
            delta,
        )
    }
}
impl StreamingSession {
    /// Appends a tool-input delta for block `index`.
    ///
    /// Flags `ContentType::ToolInput` as streamed, enters `Streaming`, and
    /// merges the delta. The merge's return value is deliberately discarded.
    pub fn on_tool_input_delta(&mut self, index: u64, delta: &str) {
        self.streamed_types.insert(ContentType::ToolInput, true);
        self.state = StreamingState::Streaming;
        let key = index.to_string();
        let Self {
            accumulated,
            key_order,
            output_started_for_key,
            ..
        } = self;
        let _ = merge_delta(
            accumulated,
            key_order,
            output_started_for_key,
            ContentType::ToolInput,
            &key,
            delta,
        );
    }

    /// Stores the (optional) tool name for block `index`, overwriting any
    /// previous value.
    pub fn set_tool_name(&mut self, index: u64, name: Option<String>) {
        self.tool_names.insert(index, name);
    }
}
impl StreamingSession {
    /// Accumulated text for `(content_type, index)`, or `None` if nothing has
    /// been stored for that key.
    pub fn get_accumulated(&self, content_type: ContentType, index: &str) -> Option<&str> {
        let content_key = (content_type, index.to_string());
        self.accumulated.get(&content_key).map(String::as_str)
    }

    /// Keys of `accumulated` for `content_type`, as ordered by `sorted_content_keys`.
    #[must_use]
    pub fn accumulated_keys(&self, content_type: ContentType) -> Vec<String> {
        sorted_content_keys(&self.accumulated, content_type)
    }

    /// Snapshots the current accumulated text for the key into `last_rendered`.
    ///
    /// Does nothing if the key has no accumulated text.
    pub fn mark_rendered(&mut self, content_type: ContentType, index: &str) {
        let content_key = (content_type, index.to_string());
        let Some(current) = self.accumulated.get(&content_key) else {
            return;
        };
        let snapshot = current.clone();
        self.last_rendered.insert(content_key, snapshot);
    }

    /// Whether the key's current accumulated text (by hash) was already rendered.
    ///
    /// Returns `false` when nothing is accumulated.
    #[must_use]
    pub fn is_content_rendered(&self, content_type: ContentType, index: &str) -> bool {
        self.accumulated
            .get(&(content_type, index.to_string()))
            .is_some_and(|current| {
                let probe = (content_type, index.to_string(), compute_hash(current));
                self.rendered_content_hashes.contains(&probe)
            })
    }

    /// Whether output (and therefore its prefix) has begun for the key.
    #[must_use]
    pub fn has_rendered_prefix(&self, content_type: ContentType, index: &str) -> bool {
        self.output_started_for_key
            .contains(&(content_type, index.to_string()))
    }

    /// Marks the key's current accumulated text as rendered.
    ///
    /// Records both the `last_rendered` snapshot and the content hash. The
    /// hash is only recorded when accumulated text exists.
    pub fn mark_content_rendered(&mut self, content_type: ContentType, index: &str) {
        self.mark_rendered(content_type, index);
        let hash = match self.accumulated.get(&(content_type, index.to_string())) {
            Some(current) => compute_hash(current),
            None => return,
        };
        self.rendered_content_hashes
            .insert((content_type, index.to_string(), hash));
    }

    /// Marks the given `content` (by hash) as rendered for the key.
    ///
    /// Also refreshes the key's `last_rendered` snapshot from accumulated text.
    pub fn mark_content_hash_rendered(
        &mut self,
        content_type: ContentType,
        index: &str,
        content: &str,
    ) {
        self.mark_rendered(content_type, index);
        let entry = (content_type, index.to_string(), compute_hash(content));
        self.rendered_content_hashes.insert(entry);
    }

    /// Whether the given `content` (by hash) was already rendered for the key.
    #[must_use]
    pub fn is_content_hash_rendered(
        &self,
        content_type: ContentType,
        index: &str,
        content: &str,
    ) -> bool {
        let probe = (content_type, index.to_string(), compute_hash(content));
        self.rendered_content_hashes.contains(&probe)
    }
}