// crabtalk 0.0.22
//
// Crabtalk library
// Documentation
//! Conversation operations: send/stream, kill, and ask/tool reply
//! routing. Pure-runtime ops live on `Runtime<C>` directly.

use crate::system::CrabTalk;
use anyhow::Result;
use crabllm_core::Provider;
use futures_util::{StreamExt, pin_mut};
use std::sync::Arc;
use wcore::AgentEvent;
use wcore::protocol::message::*;

impl<P: Provider + 'static> CrabTalk<P> {
    /// Sends a single message to an agent and waits for the complete reply.
    ///
    /// The conversation for `req.agent` is looked up or created first. When
    /// the request carries no sender (or an empty one), `"user"` is passed
    /// as the creator, while the sender handed to `send_to` stays the raw,
    /// possibly empty, value.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_or_create_conversation` or `send_to`.
    pub(crate) async fn send(&self, req: SendMsg) -> Result<SendResponse> {
        let runtime: Arc<_> = self.runtime.read().await.clone();

        // Raw sender as supplied by the caller; empty when absent.
        let sender: &str = req.sender.as_deref().unwrap_or_default();
        let created_by = match sender {
            "" => "user",
            named => named,
        };

        let conversation_id = runtime
            .get_or_create_conversation(&req.agent, created_by)
            .await?;

        let tool_choice = req
            .tool_choice
            .map(|choice| wcore::model::ToolChoice::from(choice.as_str()));

        let outcome = runtime
            .send_to(conversation_id, &req.content, sender, tool_choice)
            .await?;

        // Usage is reported as the sum over every step of the run.
        let usage = sum_usage(&outcome.steps);
        Ok(SendResponse {
            agent: req.agent,
            content: outcome.final_response.unwrap_or_default(),
            model: outcome.model,
            usage: Some(usage),
        })
    }

    /// Streams an agent's reply to `req` as a sequence of [`StreamEvent`]s.
    ///
    /// The request fields are unpacked immediately; the rest runs inside the
    /// returned `try_stream!` block: the conversation for `req.agent` is
    /// looked up or created (with `"user"` as creator when the sender is
    /// absent or empty), a client-tool listener is registered for it, a
    /// `Start` event is emitted, and each [`AgentEvent`] produced by the
    /// runtime is then translated into its protocol counterpart.
    ///
    /// When `req.guest` is non-empty, the guest's name is reported as the
    /// responding agent and events come from `guest_stream_to`. Note that
    /// `tool_choice` is not passed along on that path.
    ///
    /// The stream always finishes with an `End` event. On
    /// `AgentEvent::Done` it carries the model, the usage summed over all
    /// steps, and an error string that is non-empty only when the stop
    /// reason is `Error`. If the runtime stream ends without a `Done`, a
    /// bare `End` with empty strings and no usage is emitted instead.
    ///
    /// # Errors
    ///
    /// A failure from `get_or_create_conversation` is propagated with `?`
    /// inside the `try_stream!` block.
    pub(crate) fn stream<'a>(
        &'a self,
        req: StreamMsg,
    ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send + 'a {
        // Pull owned copies of the shared handles and request fields out
        // before building the stream.
        let runtime = self.runtime.clone();
        let tool_hook = self.tool_hook.clone();
        let agent = req.agent;
        let content = req.content;
        let sender = req.sender.unwrap_or_default();
        let guest = req.guest.unwrap_or_default();
        let tool_choice = req
            .tool_choice
            .map(|s| wcore::model::ToolChoice::from(s.as_str()));
        async_stream::try_stream! {
            let rt: Arc<_> = runtime.read().await.clone();
            let created_by = if sender.is_empty() { "user".into() } else { sender.clone() };
            let conversation_id = rt.get_or_create_conversation(&agent, created_by.as_str()).await?;
            // Register this conversation as having a stream listener so the
            // client-tools hook will forward dispatches here. The guard
            // unregisters on any exit path — stream end, early return on
            // Done, or consumer dropping the stream — and fails any
            // pending forwarded calls so they don't sit until timeout.
            tool_hook.register_listener(conversation_id);
            let _listener_guard = ListenerGuard::new(tool_hook.clone(), conversation_id);

            // Events are attributed to the guest when one is named,
            // otherwise to the addressed agent.
            let responding_agent = if guest.is_empty() { agent.clone() } else { guest.clone() };
            yield StreamEvent { event: Some(stream_event::Event::Start(StreamStart { agent: responding_agent.clone() })) };

            // Both branches are boxed so they share one stream type. The
            // guest branch does not receive `tool_choice`.
            let stream: std::pin::Pin<Box<dyn futures_core::Stream<Item = wcore::AgentEvent> + Send + '_>> = if guest.is_empty() {
                Box::pin(rt.stream_to(conversation_id, &content, &sender, tool_choice))
            } else {
                Box::pin(rt.guest_stream_to(conversation_id, &content, &sender, &guest))
            };
            pin_mut!(stream);
            while let Some(event) = stream.next().await {
                match event {
                    AgentEvent::TextStart => {
                        yield StreamEvent { event: Some(stream_event::Event::TextStart(TextStartEvent { agent: responding_agent.clone() })) };
                    }
                    AgentEvent::TextDelta(text) => {
                        yield StreamEvent { event: Some(stream_event::Event::Chunk(StreamChunk { content: text })) };
                    }
                    AgentEvent::TextEnd => {
                        yield StreamEvent { event: Some(stream_event::Event::TextEnd(TextEndEvent { agent: responding_agent.clone() })) };
                    }
                    AgentEvent::ThinkingStart => {
                        yield StreamEvent { event: Some(stream_event::Event::ThinkingStart(ThinkingStartEvent { agent: responding_agent.clone() })) };
                    }
                    AgentEvent::ThinkingDelta(text) => {
                        yield StreamEvent { event: Some(stream_event::Event::Thinking(StreamThinking { content: text })) };
                    }
                    AgentEvent::ThinkingEnd => {
                        yield StreamEvent { event: Some(stream_event::Event::ThinkingEnd(ThinkingEndEvent { agent: responding_agent.clone() })) };
                    }
                    // Emitted as a `ToolStart` carrying tool names only;
                    // the arguments field is sent empty here.
                    AgentEvent::ToolCallsBegin(calls) => {
                        yield StreamEvent { event: Some(stream_event::Event::ToolStart(ToolStartEvent {
                            calls: calls.into_iter().map(|c| ToolCallInfo {
                                name: c.function.name.to_string(),
                                arguments: String::new(),
                            }).collect(),
                        })) };
                    }
                    AgentEvent::ToolCallsStart(calls) => {
                        // Build the forward events first: `calls` is
                        // consumed by the `ToolStart` event below.
                        let forwards: Vec<ToolCallForwardEvent> = calls
                            .iter()
                            .filter(|c| tool_hook.is_client_tool(&c.function.name))
                            .map(|c| ToolCallForwardEvent {
                                call_id: c.id.to_string(),
                                name: c.function.name.to_string(),
                                arguments: c.function.arguments.clone(),
                                conversation_id,
                            })
                            .collect();

                        // `ToolStart` with full arguments goes out before
                        // any per-call forward events.
                        yield StreamEvent { event: Some(stream_event::Event::ToolStart(ToolStartEvent {
                            calls: calls.into_iter().map(|c| ToolCallInfo {
                                name: c.function.name.to_string(),
                                arguments: c.function.arguments,
                            }).collect(),
                        })) };

                        for fwd in forwards {
                            yield StreamEvent { event: Some(stream_event::Event::ToolCallForward(fwd)) };
                        }
                    }
                    AgentEvent::ToolResult { call_id, output, duration_ms } => {
                        // Flatten Ok/Err into one string; `is_error`
                        // records which variant it was.
                        let is_error = output.is_err();
                        let output = match output { Ok(s) | Err(s) => s };
                        yield StreamEvent { event: Some(stream_event::Event::ToolResult(ToolResultEvent { call_id: call_id.to_string(), output, duration_ms, is_error })) };
                    }
                    AgentEvent::ToolCallsComplete => {
                        yield StreamEvent { event: Some(stream_event::Event::ToolsComplete(ToolsCompleteEvent {})) };
                    }
                    AgentEvent::ContextUsage { ref usage } => {
                        yield StreamEvent { event: Some(stream_event::Event::ContextUsage(ContextUsageEvent { usage: Some(usage_to_proto(usage)) })) };
                    }
                    AgentEvent::UserSteered { ref content } => {
                        yield StreamEvent { event: Some(stream_event::Event::UserSteered(UserSteeredEvent { content: content.clone() })) };
                    }
                    AgentEvent::Done(resp) => {
                        // Only an `Error` stop reason yields a non-empty
                        // error string.
                        let error = if let wcore::AgentStopReason::Error(ref e) = resp.stop_reason {
                            e.clone()
                        } else {
                            String::new()
                        };
                        yield StreamEvent { event: Some(stream_event::Event::End(StreamEnd {
                            agent: responding_agent.clone(),
                            error,
                            model: resp.model,
                            usage: Some(sum_usage(&resp.steps)),
                        })) };
                        // `Done` is terminal: stop here so the fallback
                        // `End` below is not emitted as well.
                        return;
                    }
                }
            }
            // Fallback: the runtime stream ended without a `Done` event.
            yield StreamEvent { event: Some(stream_event::Event::End(StreamEnd {
                agent: responding_agent.clone(),
                error: String::new(),
                model: String::new(),
                usage: None,
            })) };
        }
    }

    /// Closes the conversation identified by `agent` and `sender`, if any.
    ///
    /// Returns `Ok(false)` when the runtime has no conversation for that
    /// pair; otherwise returns whatever `close` reports for it.
    pub(crate) async fn kill_conversation(&self, agent: &str, sender: &str) -> Result<bool> {
        let runtime = self.runtime.read().await.clone();
        let lookup = runtime.conversation_id(agent, sender).await;
        match lookup {
            Some(conversation_id) => Ok(runtime.close(conversation_id).await),
            None => Ok(false),
        }
    }

    /// Delivers a client's result for a forwarded tool call.
    ///
    /// The reply is handed to the tool hook's `try_resolve`.
    ///
    /// # Errors
    ///
    /// Fails with a "duplicate reply" error when `try_resolve` rejects the
    /// reply (returns `false`).
    pub(crate) async fn reply_to_tool(
        &self,
        conversation_id: u64,
        call_id: &str,
        output: String,
        is_error: bool,
    ) -> Result<()> {
        // A single attempt is enough: `try_resolve` also accepts replies
        // that arrive before the agent's dispatch has parked (stashed as
        // `EarlyReply`), so the dispatch/reply race is settled inside the
        // hook and no sleep-and-retry loop is needed here.
        let accepted = self
            .tool_hook
            .try_resolve(conversation_id, call_id, output, is_error);
        if !accepted {
            anyhow::bail!("duplicate reply for call_id '{call_id}'");
        }
        Ok(())
    }
}

/// RAII guard that synchronously unregisters a stream's client-tool
/// listener and drains pending forwarded calls on drop.
///
/// The guard itself only calls `unregister_listener` when dropped; it does
/// not register anything on construction.
// NOTE(review): draining of pending calls is presumably done inside
// `ToolHook::unregister_listener` — confirm against the hook.
struct ListenerGuard {
    /// Tool hook on which the listener is unregistered at drop time.
    hook: Arc<crate::hooks::tool::ToolHook>,
    /// Conversation id passed to `unregister_listener` on drop.
    conv_id: u64,
}

impl ListenerGuard {
    fn new(hook: Arc<crate::hooks::tool::ToolHook>, conv_id: u64) -> Self {
        Self { hook, conv_id }
    }
}

impl Drop for ListenerGuard {
    /// Unregisters the listener for the guarded conversation.
    fn drop(&mut self) {
        let conversation = self.conv_id;
        self.hook.unregister_listener(conversation);
    }
}

/// Adds up the token usage of every step into a single [`TokenUsage`].
///
/// The prompt, completion and total counters are always summed. The
/// optional counters (cache hit, cache miss, reasoning) stay `None` unless
/// at least one step reported a value, in which case they hold the sum of
/// the reported values.
///
/// All sums saturate at `u32::MAX` rather than overflowing, so a very long
/// run cannot panic (debug builds) or wrap to a small number (release
/// builds).
pub(super) fn sum_usage(steps: &[wcore::AgentStep]) -> TokenUsage {
    // Accumulate into an optional counter: the first reported value seeds
    // it, later ones are added with saturation.
    fn accumulate(total: Option<u32>, value: u32) -> Option<u32> {
        Some(total.map_or(value, |acc| acc.saturating_add(value)))
    }

    let mut prompt = 0u32;
    let mut completion = 0u32;
    let mut total = 0u32;
    let mut cache_hit: Option<u32> = None;
    let mut cache_miss: Option<u32> = None;
    let mut reasoning: Option<u32> = None;

    for step in steps {
        let u = &step.usage;
        prompt = prompt.saturating_add(u.prompt_tokens);
        completion = completion.saturating_add(u.completion_tokens);
        total = total.saturating_add(u.total_tokens);

        if let Some(v) = u.prompt_cache_hit_tokens {
            cache_hit = accumulate(cache_hit, v);
        }
        if let Some(v) = u.prompt_cache_miss_tokens {
            cache_miss = accumulate(cache_miss, v);
        }
        // Reasoning tokens live one level down, inside the optional
        // completion details.
        if let Some(v) = u
            .completion_tokens_details
            .as_ref()
            .and_then(|d| d.reasoning_tokens)
        {
            reasoning = accumulate(reasoning, v);
        }
    }

    TokenUsage {
        prompt_tokens: prompt,
        completion_tokens: completion,
        total_tokens: total,
        cache_hit_tokens: cache_hit,
        cache_miss_tokens: cache_miss,
        reasoning_tokens: reasoning,
    }
}

/// Converts a provider [`crabllm_core::Usage`] into the protocol's
/// [`TokenUsage`], copying each counter across unchanged.
///
/// `reasoning_tokens` is taken from the completion details and is `None`
/// when those details are absent.
fn usage_to_proto(u: &crabllm_core::Usage) -> TokenUsage {
    let reasoning_tokens = match &u.completion_tokens_details {
        Some(details) => details.reasoning_tokens,
        None => None,
    };

    TokenUsage {
        prompt_tokens: u.prompt_tokens,
        completion_tokens: u.completion_tokens,
        total_tokens: u.total_tokens,
        cache_hit_tokens: u.prompt_cache_hit_tokens,
        cache_miss_tokens: u.prompt_cache_miss_tokens,
        reasoning_tokens,
    }
}