awaken-contract 0.2.1

Core types, traits, and state model for the Awaken AI agent runtime
//! LLM executor trait and tool execution strategy.

use super::content::ContentBlock;
use super::inference::{InferenceOverride, StreamResult};
use super::message::Message;
use super::tool::ToolDescriptor;
use async_trait::async_trait;
use thiserror::Error;

/// A provider-neutral LLM inference request.
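///
/// A minimal construction sketch (illustrative values; imports elided, so the
/// example is marked `ignore`):
///
/// ```ignore
/// let request = InferenceRequest {
///     upstream_model: "example-upstream-model".into(),
///     messages: vec![Message::user("hello")],
///     tools: vec![],
///     system: vec![ContentBlock::text("You are helpful.")],
///     overrides: None,
///     enable_prompt_cache: false,
/// };
/// ```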
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    /// Effective upstream model name sent to the resolved provider executor.
    pub upstream_model: String,
    /// Messages to send.
    pub messages: Vec<Message>,
    /// Available tools.
    pub tools: Vec<ToolDescriptor>,
    /// System prompt content blocks. Empty means no system prompt.
    pub system: Vec<ContentBlock>,
    /// Per-inference overrides that remain after runtime routing is applied
    /// (temperature, max_tokens, fallback upstream models, etc.).
    pub overrides: Option<InferenceOverride>,
    /// Whether to apply prompt cache hints (e.g. `CacheControl::Ephemeral`) to system messages.
    pub enable_prompt_cache: bool,
}

/// Errors from LLM inference.
#[derive(Debug, Error)]
pub enum InferenceExecutionError {
    #[error("provider error: {0}")]
    Provider(String),
    #[error("rate limited: {0}")]
    RateLimited(String),
    #[error("timeout: {0}")]
    Timeout(String),
    #[error("cancelled")]
    Cancelled,
}

/// A token-level streaming event from the LLM.
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// Incremental text content.
    TextDelta(String),
    /// Incremental reasoning/thinking content.
    ReasoningDelta(String),
    /// A tool use block started.
    ToolCallStart { id: String, name: String },
    /// Incremental tool call argument JSON.
    ToolCallDelta { id: String, args_delta: String },
    /// A content block finished.
    ContentBlockStop,
    /// Token usage data (typically sent once at the end).
    Usage(super::inference::TokenUsage),
    /// Stop reason (end of stream).
    Stop(super::inference::StopReason),
}

/// A boxed stream of `LlmStreamEvent`s.
///
/// Implementors wrap their provider-specific streaming response into this type.
/// The loop runner consumes events, emits deltas via `EventSink`, and collects
/// the final `StreamResult`.
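///
/// A sketch of adapting an in-memory event list into this type, mirroring the
/// default `execute_stream` implementation (imports elided, so marked `ignore`):
///
/// ```ignore
/// let events: Vec<Result<LlmStreamEvent, InferenceExecutionError>> =
///     vec![Ok(LlmStreamEvent::TextDelta("hi".into()))];
/// let stream = Box::pin(futures::stream::iter(events)) as InferenceStream;
/// ```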
pub type InferenceStream = std::pin::Pin<
    Box<dyn futures::Stream<Item = Result<LlmStreamEvent, InferenceExecutionError>> + Send>,
>;

/// Abstraction over LLM inference backends.
///
/// Providers implement `execute` (returning a collected result) and may override
/// `execute_stream` to supply true token-level streaming.
/// The loop runner prefers `execute_stream` when available.
#[async_trait]
pub trait LlmExecutor: Send + Sync {
    /// Execute a chat completion and return the collected result.
    async fn execute(
        &self,
        request: InferenceRequest,
    ) -> Result<StreamResult, InferenceExecutionError>;

    /// Execute a chat completion as a token stream.
    ///
    /// The default implementation calls `execute()` and replays the collected result as a short stream of events.
    /// Override to provide true token-level streaming from the LLM provider.
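    ///
    /// Consumption sketch, assuming `executor` is any `LlmExecutor` and
    /// `futures::StreamExt` is in scope for `next()` (imports elided, so marked `ignore`):
    ///
    /// ```ignore
    /// let mut stream = executor.execute_stream(request).await?;
    /// while let Some(event) = stream.next().await {
    ///     match event? {
    ///         LlmStreamEvent::TextDelta(text) => print!("{text}"),
    ///         LlmStreamEvent::Stop(reason) => println!("\nstop: {reason:?}"),
    ///         _ => {}
    ///     }
    /// }
    /// ```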
    fn execute_stream(
        &self,
        request: InferenceRequest,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
                + Send
                + '_,
        >,
    > {
        Box::pin(async move {
            let result = self.execute(request).await?;
            let events = collected_to_stream_events(result);
            Ok(Box::pin(futures::stream::iter(events)) as InferenceStream)
        })
    }

    /// Provider name for logging/debugging.
    fn name(&self) -> &str;
}

/// Convert a collected `StreamResult` into the equivalent sequence of `LlmStreamEvent`s:
/// text/thinking deltas first, then tool-call events, then usage, then the stop reason.
pub fn collected_to_stream_events(
    result: StreamResult,
) -> Vec<Result<LlmStreamEvent, InferenceExecutionError>> {
    let mut events = Vec::new();

    // Emit text/thinking deltas from content blocks
    for block in &result.content {
        match block {
            ContentBlock::Text { text } if !text.is_empty() => {
                events.push(Ok(LlmStreamEvent::TextDelta(text.clone())));
            }
            ContentBlock::Thinking { thinking } if !thinking.is_empty() => {
                events.push(Ok(LlmStreamEvent::ReasoningDelta(thinking.clone())));
            }
            _ => {}
        }
    }

    // Emit tool calls
    for call in &result.tool_calls {
        events.push(Ok(LlmStreamEvent::ToolCallStart {
            id: call.id.clone(),
            name: call.name.clone(),
        }));
        let args = serde_json::to_string(&call.arguments).unwrap_or_default();
        if !args.is_empty() {
            events.push(Ok(LlmStreamEvent::ToolCallDelta {
                id: call.id.clone(),
                args_delta: args,
            }));
        }
    }

    // Emit usage
    if let Some(usage) = result.usage {
        events.push(Ok(LlmStreamEvent::Usage(usage)));
    }

    // Emit stop reason
    if let Some(stop) = result.stop_reason {
        events.push(Ok(LlmStreamEvent::Stop(stop)));
    }

    events
}

/// Tool execution strategy.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ToolExecutionMode {
    /// Execute tool calls one at a time.
    #[default]
    Sequential,
    /// Execute all tool calls concurrently, gated behind a single batch approval.
    ParallelBatchApproval,
    /// Execute all tool calls concurrently, streaming results as they complete.
    ParallelStreaming,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::contract::inference::{StopReason, TokenUsage};
    use crate::contract::message::ToolCall;
    use crate::contract::tool::ToolDescriptor;
    use serde_json::json;

    /// A mock LLM executor for testing.
    struct MockLlm {
        response_text: String,
        tool_calls: Vec<ToolCall>,
    }

    #[async_trait]
    impl LlmExecutor for MockLlm {
        async fn execute(
            &self,
            _request: InferenceRequest,
        ) -> Result<StreamResult, InferenceExecutionError> {
            Ok(StreamResult {
                content: if self.response_text.is_empty() {
                    vec![]
                } else {
                    vec![ContentBlock::text(self.response_text.clone())]
                },
                tool_calls: self.tool_calls.clone(),
                usage: Some(TokenUsage {
                    prompt_tokens: Some(100),
                    completion_tokens: Some(50),
                    total_tokens: Some(150),
                    ..Default::default()
                }),
                stop_reason: if self.tool_calls.is_empty() {
                    Some(StopReason::EndTurn)
                } else {
                    Some(StopReason::ToolUse)
                },
                has_incomplete_tool_calls: false,
            })
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    #[tokio::test]
    async fn mock_llm_returns_text() {
        let llm = MockLlm {
            response_text: "Hello!".into(),
            tool_calls: vec![],
        };
        let request = InferenceRequest {
            upstream_model: "test-model".into(),
            messages: vec![Message::user("hi")],
            tools: vec![],
            system: vec![],
            overrides: None,
            enable_prompt_cache: false,
        };
        let result = llm.execute(request).await.unwrap();
        assert_eq!(result.text(), "Hello!");
        assert!(!result.needs_tools());
        assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
    }

    #[tokio::test]
    async fn mock_llm_returns_tool_calls() {
        let llm = MockLlm {
            response_text: String::new(),
            tool_calls: vec![ToolCall::new("c1", "search", json!({"q": "rust"}))],
        };
        let request = InferenceRequest {
            upstream_model: "test-model".into(),
            messages: vec![Message::user("search for rust")],
            tools: vec![ToolDescriptor::new("search", "search", "Web search")],
            system: vec![ContentBlock::text("You are helpful.")],
            overrides: None,
            enable_prompt_cache: false,
        };
        let result = llm.execute(request).await.unwrap();
        assert!(result.needs_tools());
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].name, "search");
        assert_eq!(result.stop_reason, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn mock_llm_with_overrides() {
        let llm = MockLlm {
            response_text: "ok".into(),
            tool_calls: vec![],
        };
        let request = InferenceRequest {
            upstream_model: "base-model".into(),
            messages: vec![],
            tools: vec![],
            system: vec![],
            overrides: Some(InferenceOverride {
                temperature: Some(0.7),
                ..Default::default()
            }),
            enable_prompt_cache: false,
        };
        let result = llm.execute(request).await.unwrap();
        assert_eq!(result.text(), "ok");
    }
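
    // Illustrative check (sketch): the default `execute_stream` replays the
    // collected result as a short stream of events. Relies on
    // `futures::StreamExt`, which this module already pulls in via `futures`.
    #[tokio::test]
    async fn default_execute_stream_replays_collected_result() {
        use futures::StreamExt;

        let llm = MockLlm {
            response_text: "streamed".into(),
            tool_calls: vec![],
        };
        let request = InferenceRequest {
            upstream_model: "test-model".into(),
            messages: vec![Message::user("hi")],
            tools: vec![],
            system: vec![],
            overrides: None,
            enable_prompt_cache: false,
        };

        let mut stream = llm.execute_stream(request).await.unwrap();
        let mut text = String::new();
        let mut saw_stop = false;
        while let Some(event) = stream.next().await {
            match event.unwrap() {
                LlmStreamEvent::TextDelta(delta) => text.push_str(&delta),
                LlmStreamEvent::Stop(_) => saw_stop = true,
                _ => {}
            }
        }
        assert_eq!(text, "streamed");
        assert!(saw_stop);
    }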

    #[test]
    fn llm_executor_name_is_exposed() {
        let llm = MockLlm {
            response_text: String::new(),
            tool_calls: vec![],
        };

        assert_eq!(llm.name(), "mock");
    }
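
    // Illustrative ordering check (sketch): `collected_to_stream_events` emits
    // content deltas first, then tool-call start/argument events, then the
    // stop reason; absent usage is skipped.
    #[test]
    fn collected_events_preserve_order() {
        let result = StreamResult {
            content: vec![ContentBlock::text("done thinking")],
            tool_calls: vec![ToolCall::new("c1", "search", json!({"q": "rust"}))],
            usage: None,
            stop_reason: Some(StopReason::ToolUse),
            has_incomplete_tool_calls: false,
        };

        let events: Vec<_> = collected_to_stream_events(result)
            .into_iter()
            .map(Result::unwrap)
            .collect();

        assert_eq!(events.len(), 4);
        assert!(matches!(events[0], LlmStreamEvent::TextDelta(_)));
        assert!(matches!(events[1], LlmStreamEvent::ToolCallStart { .. }));
        assert!(matches!(events[2], LlmStreamEvent::ToolCallDelta { .. }));
        assert!(matches!(events[3], LlmStreamEvent::Stop(StopReason::ToolUse)));
    }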

    #[test]
    fn tool_execution_mode_default_is_sequential() {
        assert_eq!(ToolExecutionMode::default(), ToolExecutionMode::Sequential);
    }

    #[test]
    fn inference_execution_error_display_strings_are_stable() {
        assert_eq!(
            InferenceExecutionError::Provider("provider failed".into()).to_string(),
            "provider error: provider failed"
        );
        assert_eq!(
            InferenceExecutionError::RateLimited("too many requests".into()).to_string(),
            "rate limited: too many requests"
        );
        assert_eq!(
            InferenceExecutionError::Timeout("slow backend".into()).to_string(),
            "timeout: slow backend"
        );
        assert_eq!(InferenceExecutionError::Cancelled.to_string(), "cancelled");
    }
}