// oxi-ai 0.1.2-alpha
//
// Unified LLM API — multi-provider streaming interface for AI coding assistants
// Documentation
//! High-level API for oxi-ai
//!
//! Provides convenient functions for common LLM interactions.

use futures::StreamExt;
use crate::{
    Context, Model, StreamOptions, ProviderEvent, AssistantMessage, 
    ContentBlock, TextContent, ToolCall,
};
use crate::error::{Error, ProviderError};

/// High-level complete function that collects all streaming events
/// and returns the final assistant message.
///
/// # Arguments
/// * `model` - The model to use
/// * `context` - The conversation context
/// * `options` - Optional streaming options
///
/// # Returns
/// The final assistant message containing all content blocks
pub async fn complete(
    model: &Model,
    context: &Context,
    options: Option<StreamOptions>,
) -> std::result::Result<AssistantMessage, Error> {
    use crate::providers::stream;
    
    let mut stream = stream(model, context, options).await?;
    
    let mut final_message: Option<AssistantMessage> = None;
    let mut text_buffer = String::new();
    let mut current_text_index: Option<usize> = None;
    let mut tool_calls: Vec<(usize, ToolCall)> = Vec::new();
    
    while let Some(event) = stream.next().await {
        match event {
            ProviderEvent::Start { partial } => {
                final_message = Some(partial);
            }
            ProviderEvent::TextStart { content_index, partial } => {
                if final_message.is_none() {
                    final_message = Some(partial);
                }
                current_text_index = Some(content_index);
                text_buffer.clear();
            }
            ProviderEvent::TextDelta { delta, content_index, .. } => {
                text_buffer.push_str(&delta);
                if current_text_index != Some(content_index) {
                    // New text block started
                    if let Some(idx) = current_text_index {
                        // Save previous text block
                        if !text_buffer.is_empty() {
                            push_text_block(&mut final_message, idx, &text_buffer);
                        }
                    }
                    current_text_index = Some(content_index);
                    text_buffer.clear();
                }
                text_buffer.push_str(&delta);
            }
            ProviderEvent::TextEnd { content_index, content, .. } => {
                push_text_block(&mut final_message, content_index, &content);
            }
            ProviderEvent::ThinkingStart { content_index: _, partial } => {
                if final_message.is_none() {
                    final_message = Some(partial);
                }
            }
            ProviderEvent::ThinkingDelta { delta, content_index, .. } => {
                // Append thinking content
                if let Some(ref mut msg) = final_message {
                    // Find or create thinking block
                    let content = ContentBlock::Thinking(crate::ThinkingContent {
                        content_type: crate::ThinkingContentType::Thinking,
                        thinking: delta,
                        thinking_signature: None,
                        redacted: None,
                    });
                    if content_index >= msg.content.len() {
                        msg.content.push(content);
                    }
                }
            }
            ProviderEvent::ThinkingEnd { content_index, content, .. } => {
                if let Some(ref mut msg) = final_message {
                    let thinking = ContentBlock::Thinking(crate::ThinkingContent {
                        content_type: crate::ThinkingContentType::Thinking,
                        thinking: content,
                        thinking_signature: None,
                        redacted: None,
                    });
                    if content_index >= msg.content.len() {
                        msg.content.push(thinking);
                    }
                }
            }
            ProviderEvent::ToolCallStart { content_index, partial } => {
                if final_message.is_none() {
                    final_message = Some(partial);
                }
                // Initialize tool call
                let tc = ToolCall {
                    content_type: crate::ToolCallType::ToolCall,
                    id: format!("tool_call_{}", content_index),
                    name: String::new(),
                    arguments: serde_json::json!({}),
                    thought_signature: None,
                };
                tool_calls.push((content_index, tc));
            }
            ProviderEvent::ToolCallDelta { delta, content_index, .. } => {
                // Accumulate tool call arguments
                if let Some((_, tc)) = tool_calls.iter_mut().find(|(idx, _)| *idx == content_index) {
                    // Parse the accumulated args
                    let current_args = tc.arguments.to_string() + &delta;
                    if let Ok(parsed) = serde_json::from_str(&current_args) {
                        tc.arguments = parsed;
                    }
                }
            }
            ProviderEvent::ToolCallEnd { content_index, tool_call, .. } => {
                // Update or add tool call
                if let Some((_, tc)) = tool_calls.iter_mut().find(|(idx, _)| *idx == content_index) {
                    *tc = tool_call.clone();
                }
                // Add to final message content
                push_tool_call(&mut final_message, content_index, tool_call.clone());
            }
            ProviderEvent::Done { message, .. } => {
                // Finalize any remaining text
                if let Some(idx) = current_text_index {
                    if !text_buffer.is_empty() {
                        push_text_block(&mut final_message, idx, &text_buffer);
                    }
                }
                
                // Add any pending tool calls
                for (content_index, tc) in &tool_calls {
                    push_tool_call(&mut final_message, *content_index, tc.clone());
                }
                
                final_message = Some(message);
                break;
            }
            ProviderEvent::Error { error, .. } => {
                return Err(Error::Provider(ProviderError::StreamError(error.error_message.unwrap_or_else(|| "Unknown error".to_string()))));
            }
        }
    }
    
    final_message.ok_or_else(|| Error::Provider(ProviderError::StreamError("Stream ended without message".to_string())))
}

/// Stores `text` in the text block at position `index` of the message.
///
/// Does nothing when `msg` is `None`. The content list is first padded with
/// empty text blocks until `index` is a valid position. If the block at
/// `index` is a text block, `text` becomes its content when the block was
/// empty and is appended otherwise; a block of any other kind is replaced by
/// a new text block holding `text`.
fn push_text_block(msg: &mut Option<AssistantMessage>, index: usize, text: &str) {
    let message = match msg {
        Some(inner) => inner,
        None => return,
    };

    // Grow the list so that `index` can be addressed directly below.
    for _ in message.content.len()..=index {
        message.content.push(ContentBlock::Text(TextContent {
            content_type: crate::TextContentType::Text,
            text: String::new(),
        }));
    }

    match &mut message.content[index] {
        ContentBlock::Text(existing) => {
            if existing.text.is_empty() {
                *existing = TextContent::new(text);
            } else {
                existing.text.push_str(text);
            }
        }
        other => {
            // The slot holds a non-text block: overwrite it.
            *other = ContentBlock::Text(TextContent {
                content_type: crate::TextContentType::Text,
                text: text.to_string(),
            });
        }
    }
}

/// Places `tool_call` at position `index` of the message content.
///
/// Does nothing when `msg` is `None`. The content list is padded with empty
/// text blocks until `index` is a valid position; whatever block sits at
/// `index` afterwards is overwritten by the tool call.
fn push_tool_call(msg: &mut Option<AssistantMessage>, index: usize, tool_call: ToolCall) {
    let message = match msg {
        Some(inner) => inner,
        None => return,
    };

    // Pad so that `index` can be assigned directly.
    for _ in message.content.len()..=index {
        message.content.push(ContentBlock::Text(TextContent::new("")));
    }

    message.content[index] = ContentBlock::ToolCall(tool_call);
}

/// Heuristic token-count estimation helpers.
pub mod tokens {
    /// Inclusive code-point ranges that `is_cjk` treats as CJK text.
    const CJK_RANGES: &[(char, char)] = &[
        ('\u{4E00}', '\u{9FFF}'),   // CJK Unified Ideographs
        ('\u{3400}', '\u{4DBF}'),   // CJK Unified Ideographs Extension A
        ('\u{20000}', '\u{2A6DF}'), // CJK Unified Ideographs Extension B
        ('\u{2A700}', '\u{2B73F}'), // CJK Unified Ideographs Extension C
        ('\u{2B740}', '\u{2B81F}'), // CJK Unified Ideographs Extension D
        ('\u{F900}', '\u{FAFF}'),   // CJK Compatibility Ideographs
        ('\u{2F800}', '\u{2FA1F}'), // CJK Compatibility Ideographs Supplement
        ('\u{3000}', '\u{303F}'),   // CJK Symbols and Punctuation
        ('\u{3040}', '\u{309F}'),   // Hiragana
        ('\u{30A0}', '\u{30FF}'),   // Katakana
        ('\u{AC00}', '\u{D7AF}'),   // Hangul Syllables
    ];

    /// Non-ASCII punctuation and symbols that `is_punctuation` recognises in
    /// addition to ASCII punctuation: typographic quotes, ellipsis, dashes,
    /// guillemets, middle dot, bullet, reference mark and arrows.
    const EXTRA_PUNCTUATION: &[char] = &[
        '\u{201C}', '\u{201D}', '\u{2018}', '\u{2019}', '\u{2026}', '\u{2013}', '\u{2014}',
        '\u{00AB}', '\u{00BB}', '\u{00B7}', '\u{2022}', '\u{203B}', '\u{2192}', '\u{2190}',
        '\u{21D2}', '\u{2194}',
    ];

    /// Estimate the token count of `text` with a hybrid character/word
    /// heuristic.
    ///
    /// Every non-whitespace character falls into exactly one class, and the
    /// classes are weighted as follows:
    /// - **CJK characters**: 1 token per character
    /// - **Punctuation & symbols**: 1.5 tokens per character, rounded up
    /// - **Everything else** (ASCII letters, digits, other scripts):
    ///   1 token per 4 characters, rounded up
    /// - **Word boundaries**: 1 extra token per 8 whitespace-separated
    ///   words, rounded down
    ///
    /// This is an approximation aimed at GPT-4-class BPE tokenizers, not an
    /// exact count.
    ///
    /// # Arguments
    /// * `text` - The text to estimate tokens for
    ///
    /// # Returns
    /// Estimated token count; `0` for empty or whitespace-only text
    pub fn estimate(text: &str) -> usize {
        let mut cjk: usize = 0;
        let mut symbols: usize = 0;
        let mut plain: usize = 0;
        let mut words: usize = 0;
        // True until the first non-whitespace character of a word is seen.
        let mut at_word_start = true;

        for ch in text.chars() {
            if ch.is_whitespace() {
                at_word_start = true;
                continue;
            }
            if at_word_start {
                words += 1;
                at_word_start = false;
            }

            if is_cjk(ch) {
                cjk += 1;
            } else if is_punctuation(ch) {
                symbols += 1;
            } else {
                plain += 1;
            }
        }

        // 1.5 tokens per symbol, rounded up.
        let symbol_tokens = (symbols * 3 + 1) / 2;
        // 4 characters per token, rounded up.
        let plain_tokens = (plain + 3) / 4;
        // One extra token for every eight words.
        let word_tokens = words / 8;

        cjk + symbol_tokens + plain_tokens + word_tokens
    }

    /// Check whether a character lies in one of the CJK ranges: unified and
    /// compatibility ideographs, CJK symbols and punctuation, Hiragana,
    /// Katakana, or Hangul syllables.
    fn is_cjk(ch: char) -> bool {
        CJK_RANGES.iter().any(|&(first, last)| first <= ch && ch <= last)
    }

    /// Check whether a character is ASCII punctuation or one of the listed
    /// typographic symbols, which tend to tokenize into short, separate
    /// tokens.
    fn is_punctuation(ch: char) -> bool {
        // Backslash, pipe, tilde, caret and backtick are ASCII punctuation
        // and therefore already covered by the first test.
        ch.is_ascii_punctuation() || EXTRA_PUNCTUATION.contains(&ch)
    }

    /// Estimate tokens from the whitespace-separated word count alone.
    ///
    /// A simpler fallback to `estimate`: the word count is multiplied by
    /// 1.3, or by 1.6 when the text contains at least one CJK character, and
    /// the product is truncated to an integer.
    ///
    /// # Arguments
    /// * `text` - The text to estimate tokens for
    ///
    /// # Returns
    /// Estimated token count
    pub fn estimate_words(text: &str) -> usize {
        let contains_cjk = text.chars().any(is_cjk);
        let tokens_per_word = if contains_cjk { 1.6 } else { 1.3 };
        let words = text.split_whitespace().count() as f64;
        (words * tokens_per_word) as usize
    }

    /// Calculate how much of a context window the text is estimated to use.
    ///
    /// # Arguments
    /// * `text` - The text to measure
    /// * `context_window` - The model's context window size
    ///
    /// # Returns
    /// Fraction of the context window used, capped at `1.0`; `0.0` when
    /// `context_window` is zero
    pub fn context_usage(text: &str, context_window: usize) -> f64 {
        match context_window {
            0 => 0.0,
            window => {
                let ratio = estimate(text) as f64 / window as f64;
                ratio.min(1.0)
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn estimate_empty_string() {
            assert_eq!(estimate(""), 0);
        }

        #[test]
        fn estimate_plain_english() {
            // Reference point: about 8 tokens with GPT-4's tiktoken encoding.
            let tokens = estimate("Hello world, this is a test.");
            assert!((4..=14).contains(&tokens),
                "expected 4–14 tokens for plain English sentence, got {}", tokens);
        }

        #[test]
        fn estimate_cjk() {
            // Six CJK characters, each counted as one token.
            let tokens = estimate("\u{4F60}\u{597D}\u{4E16}\u{754C}\u{6D4B}\u{8BD5}");
            assert!(tokens >= 4,
                "expected >= 4 tokens for 5 CJK chars, got {}", tokens);
        }

        #[test]
        fn estimate_code() {
            // Code is punctuation-heavy, so the estimate runs high.
            let tokens = estimate("fn main() { println!(\"hello\"); }");
            assert!((4..=20).contains(&tokens),
                "expected 4–20 tokens for code snippet, got {}", tokens);
        }

        #[test]
        fn estimate_longer_than_naive() {
            // Compare against the naive `text.len() / 4` rule on
            // punctuation-heavy input.
            let text = "{ \"key\": \"value\" }";
            let hybrid = estimate(text);
            let naive = text.len() / 4;
            assert!(hybrid > 0);
            // The hybrid estimate is higher here but stays within 10x.
            assert!(hybrid <= naive * 10, "hybrid={} naive={}", hybrid, naive);
        }

        #[test]
        fn context_usage_clamped() {
            assert_eq!(context_usage("short", 0), 0.0);
            assert!(context_usage("hello", 100000) < 1.0);
        }
    }
}

// Re-export estimate_tokens as the main function
pub use tokens::estimate as estimate_tokens;