use super::models::*;
use super::utils::to_claude_usage;
use crate::proxy::common::client_adapter::{ClientAdapter, SignatureBufferStrategy};
use crate::proxy::mappers::estimation_calibrator::get_calibrator;
use crate::proxy::SignatureCache;
use bytes::Bytes;
use serde_json::{json, Value};
pub fn remap_function_call_args(name: &str, args: &mut Value) {
if let Some(obj) = args.as_object() {
tracing::debug!("[Streaming] Tool Call: '{}' Args: {:?}", name, obj);
}
if name == "EnterPlanMode" {
if let Some(obj) = args.as_object_mut() {
obj.clear();
}
return;
}
if let Some(obj) = args.as_object_mut() {
match name.to_lowercase().as_str() {
"grep" | "search" | "search_code_definitions" | "search_code_snippets" => {
if let Some(desc) = obj.remove("description") {
if !obj.contains_key("pattern") {
obj.insert("pattern".to_string(), desc);
tracing::debug!("[Streaming] Remapped Grep: description → pattern");
}
}
if let Some(query) = obj.remove("query") {
if !obj.contains_key("pattern") {
obj.insert("pattern".to_string(), query);
tracing::debug!("[Streaming] Remapped Grep: query → pattern");
}
}
if !obj.contains_key("path") {
if let Some(paths) = obj.remove("paths") {
let path_str = if let Some(arr) = paths.as_array() {
arr.first()
.and_then(|v| v.as_str())
.unwrap_or(".")
.to_string()
} else if let Some(s) = paths.as_str() {
s.to_string()
} else {
".".to_string()
};
obj.insert("path".to_string(), serde_json::json!(path_str));
tracing::debug!(
"[Streaming] Remapped Grep: paths → path(\"{}\")",
path_str
);
} else {
obj.insert("path".to_string(), json!("."));
tracing::debug!("[Streaming] Added default path: \".\"");
}
}
}
"glob" => {
if let Some(desc) = obj.remove("description") {
if !obj.contains_key("pattern") {
obj.insert("pattern".to_string(), desc);
tracing::debug!("[Streaming] Remapped Glob: description → pattern");
}
}
if let Some(query) = obj.remove("query") {
if !obj.contains_key("pattern") {
obj.insert("pattern".to_string(), query);
tracing::debug!("[Streaming] Remapped Glob: query → pattern");
}
}
if !obj.contains_key("path") {
if let Some(paths) = obj.remove("paths") {
let path_str = if let Some(arr) = paths.as_array() {
arr.first()
.and_then(|v| v.as_str())
.unwrap_or(".")
.to_string()
} else if let Some(s) = paths.as_str() {
s.to_string()
} else {
".".to_string()
};
obj.insert("path".to_string(), serde_json::json!(path_str));
tracing::debug!(
"[Streaming] Remapped Glob: paths → path(\"{}\")",
path_str
);
} else {
obj.insert("path".to_string(), json!("."));
tracing::debug!("[Streaming] Added default path: \".\"");
}
}
}
"read" => {
if let Some(path) = obj.remove("path") {
if !obj.contains_key("file_path") {
obj.insert("file_path".to_string(), path);
tracing::debug!("[Streaming] Remapped Read: path → file_path");
}
}
}
"ls" => {
if !obj.contains_key("path") {
obj.insert("path".to_string(), json!("."));
tracing::debug!("[Streaming] Remapped LS: default path → \".\"");
}
}
other => {
let mut path_to_inject = None;
if !obj.contains_key("path") {
if let Some(paths) = obj.get("paths").and_then(|v| v.as_array()) {
if paths.len() == 1 {
if let Some(p) = paths[0].as_str() {
path_to_inject = Some(p.to_string());
}
}
}
}
if let Some(path) = path_to_inject {
obj.insert("path".to_string(), json!(path));
tracing::debug!(
"[Streaming] Probabilistic fix for tool '{}': paths[0] → path(\"{}\")",
other,
path
);
}
tracing::debug!(
"[Streaming] Unmapped tool call processed via generic rules: {} (keys: {:?})",
other,
obj.keys()
);
}
}
}
}
/// Kind of content block that is currently open on the outgoing stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BlockType {
    /// No block is open.
    None,
    /// A `text` block is open.
    Text,
    /// A `thinking` block is open.
    Thinking,
    /// A `tool_use` block is open.
    Function,
}
/// Holds at most one signature that has been received but not yet emitted.
#[derive(Debug, Default)]
pub struct SignatureManager {
    pending: Option<String>,
}

impl SignatureManager {
    /// Creates a manager with no pending signature.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the pending signature with `signature` when it is `Some`.
    ///
    /// `None` is ignored, so an earlier pending signature is never cleared
    /// by this call.
    pub fn store(&mut self, signature: Option<String>) {
        if let Some(sig) = signature {
            self.pending = Some(sig);
        }
    }

    /// Takes the pending signature, leaving the manager empty.
    pub fn consume(&mut self) -> Option<String> {
        self.pending.take()
    }

    /// Returns `true` while a signature is waiting to be consumed.
    pub fn has_pending(&self) -> bool {
        self.pending.is_some()
    }
}
/// Mutable state carried across the chunks of one streamed response while it
/// is converted into SSE events.
pub struct StreamingState {
/// Kind of the content block that is currently open (`None` when closed).
block_type: BlockType,
/// Index written into content-block events; incremented when a block is closed.
pub block_index: usize,
/// Set once `message_start` has been emitted.
pub message_start_sent: bool,
/// Set once `message_stop` has been emitted.
pub message_stop_sent: bool,
/// Set when a function call was processed; selects the `tool_use` stop reason.
used_tool: bool,
/// Signature waiting to be emitted when a thinking block is closed.
signatures: SignatureManager,
/// Signature that arrived on an empty text part and has not been emitted yet.
trailing_signature: Option<String>,
/// Search query appended to the grounding text at the end of the stream.
pub web_search_query: Option<String>,
/// Grounding entries; their `web.title` / `web.uri` are rendered as citations.
pub grounding_chunks: Option<Vec<serde_json::Value>>,
/// Model name captured from `modelVersion` in the first raw chunk.
pub model_name: Option<String>,
/// Session key used when caching signatures.
pub session_id: Option<String>,
/// Passed through to `to_claude_usage`.
pub scaling_enabled: bool,
/// Passed through to `to_claude_usage`.
pub context_limit: u32,
/// Text buffered while an `<mcp__…>` XML tool call is being assembled.
pub mcp_xml_buffer: String,
/// True while text is being diverted into `mcp_xml_buffer`.
pub in_mcp_xml: bool,
/// Prompt-token estimate compared with the reported count at finish time.
pub estimated_prompt_tokens: Option<u32>,
/// Set once a thinking part has been processed.
pub has_thinking: bool,
/// Set once non-empty text or a function call has been processed.
pub has_content: bool,
/// Passed along when caching a session signature.
pub message_count: usize,
/// Optional adapter consulted for the signature buffer strategy.
pub client_adapter: Option<std::sync::Arc<dyn ClientAdapter>>,
}
impl StreamingState {
/// Creates the state for a fresh stream: no block open, nothing sent yet,
/// scaling disabled and a context limit of 1_048_576.
pub fn new() -> Self {
    Self {
        // Block bookkeeping.
        block_type: BlockType::None,
        block_index: 0,
        used_tool: false,
        // Message framing.
        message_start_sent: false,
        message_stop_sent: false,
        message_count: 0,
        // Signatures.
        signatures: SignatureManager::new(),
        trailing_signature: None,
        // Grounding / search metadata.
        web_search_query: None,
        grounding_chunks: None,
        // Request context.
        model_name: None,
        session_id: None,
        client_adapter: None,
        // Usage reporting.
        scaling_enabled: false,
        context_limit: 1_048_576,
        estimated_prompt_tokens: None,
        // MCP XML buffering.
        mcp_xml_buffer: String::new(),
        in_mcp_xml: false,
        // Content flags.
        has_thinking: false,
        has_content: false,
    }
}
/// Replaces the client adapter; pass `None` to remove it.
pub fn set_client_adapter(&mut self, adapter: Option<std::sync::Arc<dyn ClientAdapter>>) {
self.client_adapter = adapter;
}
/// Frames `data` as a single SSE event named `event_type`.
///
/// If serialization fails the `data:` line is emitted with an empty payload.
pub fn emit(&self, event_type: &str, data: serde_json::Value) -> Bytes {
    let payload = serde_json::to_string(&data).unwrap_or_default();
    Bytes::from(format!("event: {}\ndata: {}\n\n", event_type, payload))
}
/// Builds the `message_start` event from the first raw chunk.
///
/// Returns empty `Bytes` when the event was already sent. Otherwise it
/// records `modelVersion` as the model name, attaches usage when
/// `usageMetadata` deserializes, and marks the event as sent.
pub fn emit_message_start(&mut self, raw_json: &serde_json::Value) -> Bytes {
    if self.message_start_sent {
        return Bytes::new();
    }

    let usage = match raw_json.get("usageMetadata") {
        Some(raw_usage) => serde_json::from_value::<UsageMetadata>(raw_usage.clone())
            .ok()
            .map(|meta| to_claude_usage(&meta, self.scaling_enabled, self.context_limit)),
        None => None,
    };

    let response_id = raw_json
        .get("responseId")
        .and_then(|v| v.as_str())
        .unwrap_or("msg_unknown");
    let model_version = raw_json.get("modelVersion").and_then(|v| v.as_str());

    let mut message = json!({
        "id": response_id,
        "type": "message",
        "role": "assistant",
        "content": [],
        "model": model_version.unwrap_or(""),
        "stop_reason": null,
        "stop_sequence": null,
    });

    if let Some(model) = model_version {
        self.model_name = Some(model.to_string());
    }
    if let Some(usage) = usage {
        message["usage"] = json!(usage);
    }

    let event = json!({
        "type": "message_start",
        "message": message
    });
    let framed = self.emit("message_start", event);
    self.message_start_sent = true;
    framed
}
/// Opens a new content block, closing the currently open one first.
///
/// Returns the events produced: any closing events for the previous block
/// followed by `content_block_start` for the new one.
pub fn start_block(
    &mut self,
    block_type: BlockType,
    content_block: serde_json::Value,
) -> Vec<Bytes> {
    let mut events = if self.block_type == BlockType::None {
        Vec::new()
    } else {
        self.end_block()
    };

    // The index is read after the previous block was closed, so it already
    // points at the new block.
    let start_event = json!({
        "type": "content_block_start",
        "index": self.block_index,
        "content_block": content_block
    });
    events.push(self.emit("content_block_start", start_event));

    self.block_type = block_type;
    events
}
/// Closes the open content block, if any.
///
/// A thinking block first receives a `signature_delta` when a signature is
/// pending. The block index is advanced after `content_block_stop`.
pub fn end_block(&mut self) -> Vec<Bytes> {
    let mut events = Vec::new();

    match self.block_type {
        BlockType::None => return events,
        BlockType::Thinking => {
            if let Some(signature) = self.signatures.consume() {
                events.push(self.emit_delta("signature_delta", json!({ "signature": signature })));
            }
        }
        BlockType::Text | BlockType::Function => {}
    }

    let stop_event = json!({
        "type": "content_block_stop",
        "index": self.block_index
    });
    events.push(self.emit("content_block_stop", stop_event));

    self.block_type = BlockType::None;
    self.block_index += 1;
    events
}
/// Builds a `content_block_delta` event for the current block index.
///
/// The delta object contains `"type": delta_type` plus every key of
/// `delta_content` when that value is a JSON object (other values are ignored).
pub fn emit_delta(&self, delta_type: &str, delta_content: serde_json::Value) -> Bytes {
    let mut delta = serde_json::Map::new();
    delta.insert("type".to_string(), json!(delta_type));
    if let serde_json::Value::Object(fields) = delta_content {
        delta.extend(fields);
    }

    let event = json!({
        "type": "content_block_delta",
        "index": self.block_index,
        "delta": serde_json::Value::Object(delta)
    });
    self.emit("content_block_delta", event)
}
pub fn emit_finish(
&mut self,
finish_reason: Option<&str>,
usage_metadata: Option<&UsageMetadata>,
) -> Vec<Bytes> {
let mut chunks = Vec::new();
chunks.extend(self.end_block());
if let Some(signature) = self.trailing_signature.take() {
tracing::info!(
"[Streaming] Captured trailing signature (len: {}), caching for session.",
signature.len()
);
self.signatures.store(Some(signature));
}
if self.web_search_query.is_some() || self.grounding_chunks.is_some() {
let mut grounding_text = String::new();
if let Some(query) = &self.web_search_query {
if !query.is_empty() {
grounding_text.push_str("\n\n---\n**🔍 Searched for you:** ");
grounding_text.push_str(query);
}
}
if let Some(chunks) = &self.grounding_chunks {
let mut links = Vec::new();
for (i, chunk) in chunks.iter().enumerate() {
if let Some(web) = chunk.get("web") {
let title = web
.get("title")
.and_then(|v| v.as_str())
.unwrap_or("Web Source");
let uri = web.get("uri").and_then(|v| v.as_str()).unwrap_or("#");
links.push(format!("[{}] [{}]({})", i + 1, title, uri));
}
}
if !links.is_empty() {
grounding_text.push_str("\n\n**🌐 Source Citations:**\n");
grounding_text.push_str(&links.join("\n"));
}
}
if !grounding_text.is_empty() {
chunks.push(self.emit(
"content_block_start",
json!({
"type": "content_block_start",
"index": self.block_index,
"content_block": { "type": "text", "text": "" }
}),
));
chunks.push(self.emit_delta("text_delta", json!({ "text": grounding_text })));
chunks.push(self.emit(
"content_block_stop",
json!({ "type": "content_block_stop", "index": self.block_index }),
));
self.block_index += 1;
}
}
let stop_reason = if self.used_tool {
"tool_use"
} else if finish_reason == Some("MAX_TOKENS") {
"max_tokens"
} else {
"end_turn"
};
let usage = usage_metadata
.map(|u| {
if let (Some(estimated), Some(actual)) =
(self.estimated_prompt_tokens, u.prompt_token_count)
{
if estimated > 0 && actual > 0 {
get_calibrator().record(estimated, actual);
tracing::debug!(
"[Calibrator] Recorded: estimated={}, actual={}, ratio={:.2}x",
estimated,
actual,
actual as f64 / estimated as f64
);
}
}
to_claude_usage(u, self.scaling_enabled, self.context_limit)
})
.unwrap_or(Usage {
input_tokens: 0,
output_tokens: 0,
cache_read_input_tokens: None,
cache_creation_input_tokens: None,
server_tool_use: None,
});
chunks.push(self.emit(
"message_delta",
json!({
"type": "message_delta",
"delta": { "stop_reason": stop_reason, "stop_sequence": null },
"usage": usage
}),
));
if !self.message_stop_sent {
chunks.push(Bytes::from(
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
));
self.message_stop_sent = true;
}
chunks
}
/// Records that a tool call was emitted; `emit_finish` then reports `tool_use`.
pub fn mark_tool_used(&mut self) {
self.used_tool = true;
}
/// Returns the kind of the block that is currently open.
pub fn current_block_type(&self) -> BlockType {
self.block_type
}
/// Returns the index that the next content-block event will carry.
pub fn current_block_index(&self) -> usize {
self.block_index
}
/// Forwards `signature` to the signature manager (a `None` is ignored there).
pub fn store_signature(&mut self, signature: Option<String>) {
self.signatures.store(signature);
}
/// Overwrites the trailing signature; unlike `store_signature`, `None` clears it.
pub fn set_trailing_signature(&mut self, signature: Option<String>) {
self.trailing_signature = signature;
}
/// Returns `true` while a trailing signature is waiting to be emitted.
pub fn has_trailing_signature(&self) -> bool {
self.trailing_signature.is_some()
}
}
/// Converts individual `GeminiPart`s into SSE events, mutating the borrowed
/// `StreamingState` as it goes.
pub struct PartProcessor<'a> {
state: &'a mut StreamingState,
}
impl<'a> PartProcessor<'a> {
/// Creates a processor that works on `state` for the lifetime of the borrow.
pub fn new(state: &'a mut StreamingState) -> Self {
Self { state }
}
/// Converts one part into SSE events.
///
/// The part's `thought_signature` is used base64-decoded when it decodes to
/// valid UTF-8, otherwise verbatim. A function call takes precedence and
/// ends processing of the part; otherwise text (thinking or plain) and
/// inline image data are handled in that order.
pub fn process(&mut self, part: &GeminiPart) -> Vec<Bytes> {
    let mut chunks = Vec::new();

    let signature = part.thought_signature.as_ref().map(|sig| {
        use base64::Engine;
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(sig)
            .ok()
            .and_then(|bytes| String::from_utf8(bytes).ok());
        match decoded {
            Some(decoded_str) => {
                tracing::debug!(
                    "[Streaming] Decoded base64 signature (len {} -> {})",
                    sig.len(),
                    decoded_str.len()
                );
                decoded_str
            }
            None => sig.clone(),
        }
    });

    if let Some(fc) = &part.function_call {
        if self.state.has_trailing_signature() {
            chunks.extend(self.state.end_block());
            if let Some(trailing_sig) = self.state.trailing_signature.take() {
                // The signature-only thinking block is written with raw
                // events, so the state's block type stays `None` and
                // `end_block()` would not close it. Close it explicitly and
                // advance the index so the next block gets a fresh one.
                let index = self.state.current_block_index();
                chunks.push(self.state.emit(
                    "content_block_start",
                    json!({
                        "type": "content_block_start",
                        "index": index,
                        "content_block": { "type": "thinking", "thinking": "" }
                    }),
                ));
                chunks.push(
                    self.state
                        .emit_delta("thinking_delta", json!({ "thinking": "" })),
                );
                chunks.push(
                    self.state
                        .emit_delta("signature_delta", json!({ "signature": trailing_sig })),
                );
                chunks.push(self.state.emit(
                    "content_block_stop",
                    json!({ "type": "content_block_stop", "index": index }),
                ));
                self.state.block_index += 1;
            }
        }
        chunks.extend(self.process_function_call(fc, signature));
        self.state.has_content = true;
        return chunks;
    }

    if let Some(text) = &part.text {
        if part.thought.unwrap_or(false) {
            chunks.extend(self.process_thinking(text, signature));
        } else {
            chunks.extend(self.process_text(text, signature));
        }
    }

    if let Some(img) = &part.inline_data {
        let mime_type = &img.mime_type;
        let data = &img.data;
        if !data.is_empty() {
            // Inline images are forwarded as a markdown image with a data URI.
            let markdown_img = format!("![image](data:{};base64,{})", mime_type, data);
            chunks.extend(self.process_text(&markdown_img, None));
        }
    }

    chunks
}
/// Emits a thinking part.
///
/// A waiting trailing signature is first flushed as its own signature-only
/// thinking block. The text is then appended to the open thinking block
/// (opening one if needed). A signature on the part is cached for the model
/// family and the session when those are known, and stored so that it is
/// emitted when the thinking block is closed.
fn process_thinking(&mut self, text: &str, signature: Option<String>) -> Vec<Bytes> {
    let mut chunks = Vec::new();

    if self.state.has_trailing_signature() {
        chunks.extend(self.state.end_block());
        if let Some(trailing_sig) = self.state.trailing_signature.take() {
            // This block is written with raw events, so the state's block
            // type stays `None` and `end_block()` would not close it. Close
            // it explicitly and advance the index.
            let index = self.state.current_block_index();
            chunks.push(self.state.emit(
                "content_block_start",
                json!({
                    "type": "content_block_start",
                    "index": index,
                    "content_block": { "type": "thinking", "thinking": "" }
                }),
            ));
            chunks.push(
                self.state
                    .emit_delta("thinking_delta", json!({ "thinking": "" })),
            );
            chunks.push(
                self.state
                    .emit_delta("signature_delta", json!({ "signature": trailing_sig })),
            );
            chunks.push(self.state.emit(
                "content_block_stop",
                json!({ "type": "content_block_stop", "index": index }),
            ));
            self.state.block_index += 1;
        }
    }

    if self.state.current_block_type() != BlockType::Thinking {
        chunks.extend(self.state.start_block(
            BlockType::Thinking,
            json!({ "type": "thinking", "thinking": "" }),
        ));
    }
    self.state.has_thinking = true;

    if !text.is_empty() {
        chunks.push(
            self.state
                .emit_delta("thinking_delta", json!({ "thinking": text })),
        );
    }

    // Only used for the log line below.
    let use_fifo = self
        .state
        .client_adapter
        .as_ref()
        .map(|a| a.signature_buffer_strategy() == SignatureBufferStrategy::Fifo)
        .unwrap_or(false);

    if let Some(ref sig) = signature {
        if let Some(model) = &self.state.model_name {
            SignatureCache::global().cache_thinking_family(sig.clone(), model.clone());
        }
        if let Some(session_id) = &self.state.session_id {
            SignatureCache::global().cache_session_signature(
                session_id,
                sig.clone(),
                self.state.message_count,
            );
            tracing::debug!(
                "[Claude-SSE] Cached signature to session {} (length: {}) [FIFO: {}]",
                session_id,
                sig.len(),
                use_fifo
            );
        }
        tracing::debug!(
            "[Claude-SSE] Captured thought_signature from thinking block (length: {})",
            sig.len()
        );
    }

    self.state.store_signature(signature);
    chunks
}
/// Emits a plain-text part.
///
/// * Empty text produces no events; a signature on it is kept as the
///   trailing signature.
/// * A waiting trailing signature is flushed as a signature-only thinking
///   block before anything else.
/// * Text that carries a signature is stored as pending and the text is
///   emitted in a block of its own.
/// * Text containing `<mcp__` (or arriving while such a tag is being
///   buffered) is held back until `<mcp__NAME>…</mcp__NAME>` is complete. The
///   tag body is then emitted as a tool call (parsed as JSON, or wrapped as
///   `{"input": …}`), preceded by any text before the tag and followed by
///   any text after it.
/// * Everything else is appended to the open text block.
fn process_text(&mut self, text: &str, signature: Option<String>) -> Vec<Bytes> {
    let mut chunks = Vec::new();

    if text.is_empty() {
        if signature.is_some() {
            self.state.set_trailing_signature(signature);
        }
        return chunks;
    }
    self.state.has_content = true;

    if self.state.has_trailing_signature() {
        chunks.extend(self.state.end_block());
        if let Some(trailing_sig) = self.state.trailing_signature.take() {
            // This block is written with raw events, so the state's block
            // type stays `None` and `end_block()` would not close it. Close
            // it explicitly and advance the index.
            let index = self.state.current_block_index();
            chunks.push(self.state.emit(
                "content_block_start",
                json!({
                    "type": "content_block_start",
                    "index": index,
                    "content_block": { "type": "thinking", "thinking": "" }
                }),
            ));
            chunks.push(
                self.state
                    .emit_delta("thinking_delta", json!({ "thinking": "" })),
            );
            chunks.push(
                self.state
                    .emit_delta("signature_delta", json!({ "signature": trailing_sig })),
            );
            chunks.push(self.state.emit(
                "content_block_stop",
                json!({ "type": "content_block_stop", "index": index }),
            ));
            self.state.block_index += 1;
        }
    }

    if signature.is_some() {
        // Stored before the block switch so that closing an open thinking
        // block can pick the signature up.
        self.state.store_signature(signature);
        chunks.extend(
            self.state
                .start_block(BlockType::Text, json!({ "type": "text", "text": "" })),
        );
        chunks.push(self.state.emit_delta("text_delta", json!({ "text": text })));
        chunks.extend(self.state.end_block());
        return chunks;
    }

    if text.contains("<mcp__") || self.state.in_mcp_xml {
        self.state.in_mcp_xml = true;
        self.state.mcp_xml_buffer.push_str(text);

        if self.state.mcp_xml_buffer.contains("</mcp__")
            && self.state.mcp_xml_buffer.contains('>')
        {
            // Taken out of the state so it can be sliced while the state is
            // mutated; it is put back below if no complete call is found.
            let buffer = std::mem::take(&mut self.state.mcp_xml_buffer);

            if let Some(start_idx) = buffer.find("<mcp__") {
                if let Some(tag_len) = buffer[start_idx..].find('>') {
                    let tag_end = start_idx + tag_len;
                    let body_start = tag_end + 1;
                    let tool_name = &buffer[start_idx + 1..tag_end];
                    let end_tag = format!("</{}>", tool_name);

                    // Look for the closing tag only after the opening tag. A
                    // stray closing tag earlier in the buffer would otherwise
                    // produce an inverted (panicking) slice range.
                    let close = buffer[body_start..]
                        .find(&end_tag)
                        .map(|rel| body_start + rel);

                    if let Some(close_idx) = close {
                        let input_str = buffer[body_start..close_idx].trim();
                        let input_json: serde_json::Value = serde_json::from_str(input_str)
                            .unwrap_or_else(|_| json!({ "input": input_str }));
                        let fc = FunctionCall {
                            name: tool_name.to_string(),
                            args: Some(input_json),
                            id: Some(format!("{}-xml", tool_name)),
                        };
                        self.state.in_mcp_xml = false;

                        // Text before the tag goes out first so block
                        // indices stay in stream order; starting the tool
                        // block closes the text block.
                        if start_idx > 0 {
                            let prefix_text = &buffer[..start_idx];
                            if self.state.current_block_type() != BlockType::Text {
                                chunks.extend(self.state.start_block(
                                    BlockType::Text,
                                    json!({ "type": "text", "text": "" }),
                                ));
                            }
                            chunks.push(
                                self.state
                                    .emit_delta("text_delta", json!({ "text": prefix_text })),
                            );
                        }

                        chunks.extend(self.process_function_call(&fc, None));

                        let suffix = &buffer[close_idx + end_tag.len()..];
                        if !suffix.is_empty() {
                            chunks.extend(self.process_text(suffix, None));
                        }
                        return chunks;
                    }
                }
            }

            // Not a complete call yet: keep buffering.
            self.state.mcp_xml_buffer = buffer;
        }

        // Return what was already produced (e.g. the trailing-signature
        // block) instead of discarding it.
        return chunks;
    }

    if self.state.current_block_type() != BlockType::Text {
        chunks.extend(
            self.state
                .start_block(BlockType::Text, json!({ "type": "text", "text": "" })),
        );
    }
    chunks.push(self.state.emit_delta("text_delta", json!({ "text": text })));
    chunks
}
/// Emits a complete `tool_use` block for `fc`: start, one
/// `input_json_delta` with the remapped arguments (when present), stop.
///
/// Marks the state as having used a tool. The tool id falls back to
/// `<name>-<random id>`, a tool named `search` (any case) is reported as
/// `grep`, and a signature is attached to the block and cached by tool id
/// and, when known, by session.
fn process_function_call(
    &mut self,
    fc: &FunctionCall,
    signature: Option<String>,
) -> Vec<Bytes> {
    self.state.mark_tool_used();

    let tool_id = match &fc.id {
        Some(id) => id.clone(),
        None => format!(
            "{}-{}",
            fc.name,
            crate::proxy::common::utils::generate_random_id()
        ),
    };

    let is_search = fc.name.to_lowercase() == "search";
    let tool_name = if is_search {
        tracing::debug!("[Streaming] Normalizing tool name: Search → grep");
        "grep".to_string()
    } else {
        fc.name.clone()
    };

    let mut tool_use = json!({
        "type": "tool_use",
        "id": tool_id,
        "name": tool_name,
        "input": {}
    });

    if let Some(sig) = signature.as_ref() {
        tool_use["signature"] = json!(sig);
        SignatureCache::global().cache_tool_signature(&tool_id, sig.clone());
        if let Some(session_id) = &self.state.session_id {
            SignatureCache::global().cache_session_signature(
                session_id,
                sig.clone(),
                self.state.message_count,
            );
        }
        tracing::debug!(
            "[Claude-SSE] Captured thought_signature for function call (length: {})",
            sig.len()
        );
    }

    let mut events = self.state.start_block(BlockType::Function, tool_use);

    if let Some(args) = &fc.args {
        let mut remapped_args = args.clone();
        // The argument remapper is handed "Grep" for search calls and the
        // original name otherwise.
        let remap_name = if is_search { "Grep" } else { fc.name.as_str() };
        remap_function_call_args(remap_name, &mut remapped_args);

        let json_str = match serde_json::to_string(&remapped_args) {
            Ok(serialized) => serialized,
            Err(_) => "{}".to_string(),
        };
        events.push(
            self.state
                .emit_delta("input_json_delta", json!({ "partial_json": json_str })),
        );
    }

    events.extend(self.state.end_block());
    events
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Joins emitted chunks into one string for substring assertions.
    fn render(chunks: &[Bytes]) -> String {
        chunks
            .iter()
            .map(|b| String::from_utf8(b.to_vec()).unwrap())
            .collect::<Vec<_>>()
            .join("")
    }

    /// A stored signature is pending until it is consumed exactly once.
    #[test]
    fn test_signature_manager() {
        let mut manager = SignatureManager::new();
        assert!(!manager.has_pending());

        manager.store(Some("sig123".to_string()));
        assert!(manager.has_pending());

        let taken = manager.consume();
        assert_eq!(taken, Some("sig123".to_string()));
        assert!(!manager.has_pending());
    }

    /// `emit` writes the event name and the serialized payload.
    #[test]
    fn test_streaming_state_emit() {
        let state = StreamingState::new();
        let framed = state.emit("test_event", json!({"foo": "bar"}));
        let text = String::from_utf8(framed.to_vec()).unwrap();

        assert!(text.contains("event: test_event"));
        assert!(text.contains("\"foo\":\"bar\""));
    }

    /// A function-call part yields start, input delta and stop events.
    #[test]
    fn test_process_function_call_deltas() {
        let mut state = StreamingState::new();
        let mut processor = PartProcessor::new(&mut state);

        let part = GeminiPart {
            text: None,
            function_call: Some(FunctionCall {
                name: "test_tool".to_string(),
                args: Some(json!({"arg": "value"})),
                id: Some("call_123".to_string()),
            }),
            inline_data: None,
            thought: None,
            thought_signature: None,
            function_response: None,
        };

        let output = render(&processor.process(&part));

        assert!(output.contains(r#""type":"content_block_start""#));
        assert!(output.contains(r#""name":"test_tool""#));
        assert!(output.contains(r#""input":{}"#));
        assert!(output.contains(r#""type":"content_block_delta""#));
        assert!(output.contains(r#""type":"input_json_delta""#));
        assert!(output.contains(r#"partial_json":"{\"arg\":\"value\"}"#));
        assert!(output.contains(r#""type":"content_block_stop""#));
    }
}