// chabeau 0.7.3
//
// A full-screen terminal chat interface that connects to various AI APIs for real-time conversations.
// Documentation
use crate::api::ChatMessage;
use rust_mcp_schema::{
    CreateMessageRequest, CreateMessageRequestParams, MessageMeta, Role, SamplingMessageContent,
    SamplingMessageContentBlock, TextContent,
};
use std::time::Duration;

/// Produces a one-line, human-readable summary of a sampling request.
///
/// The summary reports the number of messages, the requested `max_tokens`, and either a
/// shortened form of the system prompt or `none` when no system prompt is present.
pub fn summarize_sampling_request(request: &CreateMessageRequest) -> String {
    let params = &request.params;
    let prompt_part = match params.system_prompt.as_ref() {
        Some(prompt) => format!("system prompt: {}", summarize_prompt(prompt)),
        None => "system prompt: none".to_string(),
    };
    format!(
        "messages: {}, maxTokens: {}, {}",
        params.messages.len(),
        params.max_tokens,
        prompt_part
    )
}

/// Serializes the request parameters to a JSON string.
///
/// If serialization fails, the empty JSON object `{}` is returned instead of an error.
pub fn serialize_sampling_params(request: &CreateMessageRequest) -> String {
    match serde_json::to_string(&request.params) {
        Ok(json) => json,
        Err(_) => "{}".to_string(),
    }
}

/// Converts a sampling request into the chat messages sent to the model.
///
/// A system prompt that is not blank after trimming becomes a leading `system` message
/// (its text is passed through untrimmed). Each sampling message follows in order, with its
/// role mapped to `user` or `assistant`.
///
/// # Errors
///
/// Returns the error produced by `sampling_content_to_text` for the first message whose
/// content cannot be represented as text.
pub fn build_sampling_messages(request: &CreateMessageRequest) -> Result<Vec<ChatMessage>, String> {
    let params = &request.params;

    // Every message produced here carries only a role and text content.
    let plain = |role: &str, content: String| ChatMessage {
        role: role.to_string(),
        content,
        name: None,
        tool_call_id: None,
        tool_calls: None,
    };

    let mut out = Vec::with_capacity(params.messages.len() + 1);

    match params.system_prompt.as_ref() {
        Some(prompt) if !prompt.trim().is_empty() => out.push(plain("system", prompt.clone())),
        _ => {}
    }

    for entry in &params.messages {
        let text = sampling_content_to_text(&entry.content)?;
        let role = match entry.role {
            Role::User => "user",
            Role::Assistant => "assistant",
        };
        out.push(plain(role, text));
    }

    Ok(out)
}

/// Returns the timeout requested through the request's metadata, if any.
///
/// Yields `None` when the parameters carry no metadata or no usable timeout entry.
pub fn sampling_timeout_for_request(request: &CreateMessageRequest) -> Option<Duration> {
    let params = &request.params;
    sampling_timeout_from_params(params)
}

fn sampling_timeout_from_params(params: &CreateMessageRequestParams) -> Option<Duration> {
    let meta = params.meta.as_ref()?;
    sampling_timeout_from_meta(meta)
}

/// Extracts a sampling timeout from the free-form `extra` map of the message metadata.
///
/// Keys are compared case-insensitively by substring:
/// - a key containing `timeoutms` is read as a whole number of milliseconds;
/// - any other key containing `timeout` is read as a (possibly fractional) number of seconds.
///
/// A millisecond value takes precedence over a seconds value. Entries whose value cannot be
/// parsed are skipped; they never discard a value accepted from another entry.
///
/// Returns `None` when there is no `extra` map or no entry yields a usable value.
fn sampling_timeout_from_meta(meta: &MessageMeta) -> Option<Duration> {
    let extra = meta.extra.as_ref()?;
    let mut timeout_ms: Option<u64> = None;
    let mut timeout_secs: Option<f64> = None;

    for (key, value) in extra {
        let key = key.to_ascii_lowercase();
        if key.contains("timeoutms") {
            // Keep the previous value when this entry does not parse.
            timeout_ms = parse_positive_number(value).or(timeout_ms);
        } else if key.contains("timeout") {
            // Same rule as above: an invalid entry must not erase a valid one seen earlier.
            timeout_secs = parse_positive_float(value).or(timeout_secs);
        }
    }

    if let Some(ms) = timeout_ms {
        return Some(Duration::from_millis(ms));
    }
    let secs = timeout_secs?;
    // Round up to whole milliseconds; the float-to-integer cast saturates at `u64::MAX`.
    Some(Duration::from_millis((secs * 1000.0).ceil() as u64))
}

/// Reads a strictly positive integer from a JSON value.
///
/// Accepts a JSON number representable as `u64`, or a string that parses as `u64` after
/// trimming whitespace. Any other value kind yields `None`.
///
/// Zero is rejected so that it is treated as "no timeout given" rather than turned into a
/// zero-length timeout, matching `parse_positive_float`.
fn parse_positive_number(value: &serde_json::Value) -> Option<u64> {
    let parsed = match value {
        serde_json::Value::Number(number) => number.as_u64(),
        serde_json::Value::String(text) => text.trim().parse::<u64>().ok(),
        _ => None,
    };
    parsed.filter(|&amount| amount > 0)
}

/// Reads a strictly positive floating-point number from a JSON value.
///
/// Accepts a JSON number or a string that parses as `f64` after trimming whitespace.
/// Values that are not greater than zero (including NaN) and other value kinds yield `None`.
fn parse_positive_float(value: &serde_json::Value) -> Option<f64> {
    let parsed = match value {
        serde_json::Value::Number(number) => number.as_f64(),
        serde_json::Value::String(text) => text.trim().parse::<f64>().ok(),
        _ => None,
    };
    parsed.filter(|amount| *amount > 0.0)
}

/// Translates a chat-completion finish reason into the stop-reason string reported back
/// for a sampling result.
///
/// Known reasons are rewritten (`stop` → `endTurn`, `length` → `maxTokens`,
/// `tool_calls` → `toolUse`, `content_filter` → `stopSequence`); any other reason is passed
/// through unchanged, and `None` stays `None`.
pub fn map_finish_reason(reason: Option<String>) -> Option<String> {
    reason.map(|raw| match raw.as_str() {
        "stop" => "endTurn".to_string(),
        "length" => "maxTokens".to_string(),
        "tool_calls" => "toolUse".to_string(),
        "content_filter" => "stopSequence".to_string(),
        _ => raw,
    })
}

fn sampling_content_to_text(content: &SamplingMessageContent) -> Result<String, String> {
    match content {
        SamplingMessageContent::TextContent(text) => Ok(text.text.clone()),
        SamplingMessageContent::SamplingMessageContentBlock(blocks) => {
            let mut parts = Vec::new();
            for block in blocks {
                parts.push(sampling_block_to_text(block)?);
            }
            Ok(parts.join("\n"))
        }
        _ => Err("Sampling content must be text-only for this client.".to_string()),
    }
}

/// Returns the text of a single content block.
///
/// # Errors
///
/// Returns an error message when the block is anything other than text content.
fn sampling_block_to_text(block: &SamplingMessageContentBlock) -> Result<String, String> {
    if let SamplingMessageContentBlock::TextContent(inner) = block {
        return Ok(inner.text.clone());
    }
    Err("Sampling content blocks must be text-only for this client.".to_string())
}

/// Shortens a prompt for display in a one-line summary.
///
/// Surrounding whitespace is trimmed. A prompt of at most 48 characters is returned whole;
/// a longer one is cut to its first 48 characters followed by an ellipsis (`…`).
///
/// The limit is measured in characters rather than bytes, so a short prompt made of
/// multi-byte characters is not marked as truncated.
fn summarize_prompt(prompt: &str) -> String {
    const MAX_CHARS: usize = 48;

    let trimmed = prompt.trim();
    // The byte offset of the character just past the limit is a valid slice boundary;
    // there is no such character when the prompt already fits.
    match trimmed.char_indices().nth(MAX_CHARS) {
        None => trimmed.to_string(),
        Some((cut, _)) => format!("{}…", &trimmed[..cut]),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rust_mcp_schema::{
        CreateMessageRequest, CreateMessageRequestParams, ImageContent, MessageMeta, RequestId,
        Role, SamplingMessage, SamplingMessageContent, TextContent,
    };
    use serde_json::json;

    /// Builds a request with fixed defaults, a system prompt, and the given messages.
    fn base_request(messages: Vec<SamplingMessage>) -> CreateMessageRequest {
        CreateMessageRequest::new(
            RequestId::Integer(1),
            CreateMessageRequestParams {
                include_context: None,
                max_tokens: 16,
                messages,
                meta: None,
                metadata: None,
                model_preferences: None,
                stop_sequences: Vec::new(),
                system_prompt: Some("System prompt".to_string()),
                task: None,
                temperature: None,
                tool_choice: None,
                tools: Vec::new(),
            },
        )
    }

    /// Wraps content in a user-role sampling message without metadata.
    fn user_message(content: SamplingMessageContent) -> SamplingMessage {
        SamplingMessage {
            role: Role::User,
            content,
            meta: None,
        }
    }

    /// Builds a message-less request whose metadata holds a single extra entry.
    fn request_with_extra(key: &str, value: serde_json::Value) -> CreateMessageRequest {
        let mut extra = serde_json::Map::new();
        extra.insert(key.to_string(), value);

        let mut request = base_request(Vec::new());
        request.params.meta = Some(MessageMeta {
            progress_token: None,
            extra: Some(extra),
        });
        request
    }

    #[test]
    fn build_sampling_messages_includes_system_prompt() {
        let text = TextContent::new("Hello".to_string(), None, None);
        let request = base_request(vec![user_message(SamplingMessageContent::from(text))]);

        let messages = build_sampling_messages(&request).expect("messages");

        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, "system");
        assert_eq!(messages[0].content, "System prompt");
        assert_eq!(messages[1].role, "user");
        assert_eq!(messages[1].content, "Hello");
    }

    #[test]
    fn build_sampling_messages_rejects_image_content() {
        let image = ImageContent::new("data".to_string(), "image/png".to_string(), None, None);
        let request = base_request(vec![user_message(SamplingMessageContent::from(image))]);

        let err = build_sampling_messages(&request)
            .err()
            .expect("should reject image");

        assert!(err.contains("text-only"));
    }

    #[test]
    fn sampling_timeout_prefers_timeout_ms() {
        let request = request_with_extra("server/timeoutMs", json!(120_000));

        let timeout = sampling_timeout_for_request(&request).expect("timeout");

        assert_eq!(timeout.as_millis(), 120_000);
    }

    #[test]
    fn sampling_timeout_accepts_timeout_seconds() {
        let request = request_with_extra("timeout", json!(90));

        let timeout = sampling_timeout_for_request(&request).expect("timeout");

        assert_eq!(timeout.as_secs(), 90);
    }

    #[test]
    fn sampling_timeout_ignores_invalid_values() {
        let request = request_with_extra("timeoutMs", json!("nope"));

        assert!(sampling_timeout_for_request(&request).is_none());
    }
}