aether-llm 0.1.9

Multi-provider LLM abstraction layer for the Aether AI agent framework
Documentation
use async_openai::types::chat::{
    ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessage,
    ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
    ChatCompletionRequestMessageContentPartAudio, ChatCompletionRequestMessageContentPartImage,
    ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage, ChatCompletionRequestToolMessage,
    ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
    ChatCompletionRequestUserMessageContentPart, ChatCompletionTool, ChatCompletionTools, FunctionCall, FunctionObject,
    ImageUrl, InputAudio, InputAudioFormat,
};

use crate::{ChatMessage, ContentBlock, LlmError, Result, ToolDefinition};

fn map_message(msg: ChatMessage) -> Result<Option<ChatCompletionRequestMessage>> {
    match msg {
        ChatMessage::System { content, .. } => {
            Ok(Some(ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
                content: content.into(),
                name: None,
            })))
        }
        ChatMessage::User { content, .. } => {
            Ok(Some(ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
                content: map_user_content(content)?,
                name: None,
            })))
        }
        ChatMessage::Assistant { content, tool_calls, .. } => {
            let openai_tool_calls: Vec<_> = tool_calls
                .into_iter()
                .map(|call| {
                    ChatCompletionMessageToolCalls::Function(ChatCompletionMessageToolCall {
                        id: call.id,
                        function: FunctionCall { name: call.name, arguments: call.arguments },
                    })
                })
                .collect();

            let tool_calls = (!openai_tool_calls.is_empty()).then_some(openai_tool_calls);

            Ok(Some(ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
                content: Some(ChatCompletionRequestAssistantMessageContent::Text(content)),
                name: None,
                tool_calls,
                audio: None,
                refusal: None,
                #[allow(deprecated)]
                function_call: None,
            })))
        }
        ChatMessage::ToolCallResult(result) => {
            let (content, id) = match result {
                Ok(r) => (r.result, r.id),
                Err(e) => (e.error, e.id),
            };
            Ok(Some(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
                content: ChatCompletionRequestToolMessageContent::Text(content),
                tool_call_id: id,
            })))
        }
        ChatMessage::Summary { content, .. } => {
            Ok(Some(ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
                content: format!("[Previous conversation handoff]\n\n{content}").into(),
                name: None,
            })))
        }
        ChatMessage::Error { .. } => Ok(None),
    }
}

/// Converts a conversation history into OpenAI chat-completions request messages.
///
/// Messages that have no OpenAI representation (those for which `map_message` yields `None`,
/// such as `ChatMessage::Error`) are skipped. Order of the remaining messages is preserved.
///
/// # Errors
///
/// Returns the first error produced while mapping an individual message, e.g.
/// [`LlmError::UnsupportedContent`] for audio formats the endpoint cannot accept.
pub fn map_messages(messages: &[ChatMessage]) -> Result<Vec<ChatCompletionRequestMessage>> {
    messages
        .iter()
        .cloned()
        .map(map_message)
        // `Result<Option<_>>` -> `Option<Result<_>>`: drops skipped messages (`Ok(None)`) while
        // keeping both successes and errors, so `collect` short-circuits on the first `Err`.
        .filter_map(|mapped| mapped.transpose())
        .collect()
}

pub fn map_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
    tools.iter().map(tool_definition_to_openai).collect()
}

/// Converts user content blocks into OpenAI user-message content.
///
/// Always produces the multi-part `Array` form, even when every block is text.
///
/// # Errors
///
/// Fails on the first block that cannot be represented (see `map_user_content_part`).
fn map_user_content(parts: Vec<ContentBlock>) -> Result<ChatCompletionRequestUserMessageContent> {
    parts
        .into_iter()
        .map(map_user_content_part)
        .collect::<Result<Vec<_>>>()
        .map(ChatCompletionRequestUserMessageContent::Array)
}

/// Converts one content block into an OpenAI user-message content part.
///
/// Text is passed through, images are inlined as `data:<mime>;base64,<data>` URLs, and audio is
/// sent as `input_audio` with a format derived from its MIME type.
///
/// # Errors
///
/// Returns [`LlmError::UnsupportedContent`] when an audio block's MIME type has no OpenAI format.
fn map_user_content_part(part: ContentBlock) -> Result<ChatCompletionRequestUserMessageContentPart> {
    let converted = match part {
        ContentBlock::Text { text } => {
            ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText { text })
        }
        ContentBlock::Image { data, mime_type } => {
            let url = format!("data:{mime_type};base64,{data}");
            let image_url = ImageUrl { url, detail: None };
            ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
                image_url,
            })
        }
        ContentBlock::Audio { data, mime_type } => {
            let format = map_audio_format(&mime_type)?;
            let input_audio = InputAudio { data, format };
            ChatCompletionRequestUserMessageContentPart::InputAudio(ChatCompletionRequestMessageContentPartAudio {
                input_audio,
            })
        }
    };

    Ok(converted)
}

/// Maps an audio MIME type onto an OpenAI chat-completions input audio format.
///
/// MIME types are compared case-insensitively and any parameters (e.g. `; codecs=1`) are
/// ignored, as MIME type/subtype tokens are case-insensitive. Common WAV aliases
/// (`audio/x-wav`, `audio/wave`, `audio/vnd.wave`) are accepted alongside `audio/wav`.
///
/// # Errors
///
/// Returns [`LlmError::UnsupportedContent`] for any format other than WAV or MP3. The message
/// echoes the MIME type exactly as supplied.
fn map_audio_format(mime_type: &str) -> Result<InputAudioFormat> {
    // Compare only the `type/subtype` essence: strip parameters, surrounding whitespace and case.
    let essence = mime_type
        .split_once(';')
        .map_or(mime_type, |(essence, _params)| essence)
        .trim()
        .to_ascii_lowercase();

    match essence.as_str() {
        "audio/wav" | "audio/x-wav" | "audio/wave" | "audio/vnd.wave" => Ok(InputAudioFormat::Wav),
        "audio/mpeg" | "audio/mp3" => Ok(InputAudioFormat::Mp3),
        _ => Err(LlmError::UnsupportedContent(format!(
            "OpenAI chat completions does not support {mime_type} audio input"
        ))),
    }
}

fn tool_definition_to_openai(tool: &ToolDefinition) -> Result<ChatCompletionTools> {
    let parameters = serde_json::from_str(&tool.parameters)
        .map_err(|e| LlmError::ToolParameterParsing { tool_name: tool.name.clone(), error: e.to_string() })?;

    Ok(ChatCompletionTools::Function(ChatCompletionTool {
        function: FunctionObject {
            name: tool.name.clone(),
            description: Some(tool.description.clone()),
            parameters: Some(parameters),
            strict: Some(false),
        },
    }))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::IsoString;

    /// Unwraps the multi-part form of user content, failing the test on the plain-text form.
    fn expect_parts(
        content: ChatCompletionRequestUserMessageContent,
    ) -> Vec<ChatCompletionRequestUserMessageContentPart> {
        match content {
            ChatCompletionRequestUserMessageContent::Array(parts) => parts,
            other @ ChatCompletionRequestUserMessageContent::Text(_) => panic!("Expected Array, got {other:?}"),
        }
    }

    /// Returns the audio format of an `InputAudio` part, failing the test on any other part kind.
    fn expect_audio_format(part: &ChatCompletionRequestUserMessageContentPart) -> &InputAudioFormat {
        match part {
            ChatCompletionRequestUserMessageContentPart::InputAudio(aud) => &aud.input_audio.format,
            other => panic!("Expected InputAudio, got {other:?}"),
        }
    }

    /// Builds a short text prompt followed by an audio block with the given MIME type.
    fn prompt_with_audio(mime_type: &str) -> Vec<ContentBlock> {
        vec![
            ContentBlock::text("Listen:"),
            ContentBlock::Audio { data: "YXVkaW9kYXRh".to_string(), mime_type: mime_type.to_string() },
        ]
    }

    /// Builds a server-less tool definition with the given raw JSON parameter schema.
    fn tool_with_parameters(name: &str, description: &str, parameters: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: description.to_string(),
            parameters: parameters.to_string(),
            server: None,
        }
    }

    #[test]
    fn map_user_content_text_only() {
        let parts = expect_parts(map_user_content(vec![ContentBlock::text("Hello")]).unwrap());
        assert_eq!(parts.len(), 1);
        assert!(matches!(
            &parts[0],
            ChatCompletionRequestUserMessageContentPart::Text(t) if t.text == "Hello"
        ));
    }

    #[test]
    fn map_user_content_with_image_produces_array() {
        let blocks = vec![
            ContentBlock::text("Look at this:"),
            ContentBlock::Image { data: "aW1hZ2VkYXRh".to_string(), mime_type: "image/png".to_string() },
        ];
        let parts = expect_parts(map_user_content(blocks).unwrap());
        assert_eq!(parts.len(), 2);
        assert!(matches!(&parts[0], ChatCompletionRequestUserMessageContentPart::Text(_)));
        match &parts[1] {
            ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
                assert!(img.image_url.url.starts_with("data:image/png;base64,"));
            }
            other => panic!("Expected ImageUrl, got {other:?}"),
        }
    }

    #[test]
    fn map_user_content_with_audio_produces_array() {
        let parts = expect_parts(map_user_content(prompt_with_audio("audio/wav")).unwrap());
        assert_eq!(parts.len(), 2);
        assert_eq!(expect_audio_format(&parts[1]), &InputAudioFormat::Wav);
    }

    #[test]
    fn map_user_content_with_mpeg_audio_produces_mp3_format() {
        let parts = expect_parts(map_user_content(prompt_with_audio("audio/mpeg")).unwrap());
        assert_eq!(expect_audio_format(&parts[1]), &InputAudioFormat::Mp3);
    }

    #[test]
    fn map_user_content_with_mp3_audio_produces_mp3_format() {
        let parts = expect_parts(map_user_content(prompt_with_audio("audio/mp3")).unwrap());
        assert_eq!(expect_audio_format(&parts[1]), &InputAudioFormat::Mp3);
    }

    #[test]
    fn map_user_content_with_ogg_returns_unsupported_content() {
        let outcome = map_user_content(prompt_with_audio("audio/ogg"));
        assert!(matches!(outcome, Err(LlmError::UnsupportedContent(_))));
    }

    #[test]
    fn map_text_only_user_message_unchanged() {
        let messages =
            vec![ChatMessage::User { content: vec![ContentBlock::text("Hello")], timestamp: IsoString::now() }];
        let mapped = map_messages(&messages).unwrap();
        assert_eq!(mapped.len(), 1);
    }

    #[test]
    fn map_messages_with_ogg_audio_returns_unsupported_content() {
        let messages = vec![ChatMessage::User { content: prompt_with_audio("audio/ogg"), timestamp: IsoString::now() }];
        assert!(matches!(map_messages(&messages), Err(LlmError::UnsupportedContent(_))));
    }

    #[test]
    fn map_tools_with_valid_json() {
        let tools = vec![tool_with_parameters(
            "search",
            "Search for things",
            r#"{"type": "object", "properties": {"q": {"type": "string"}}}"#,
        )];
        let mapped = map_tools(&tools);
        assert!(mapped.is_ok());
        assert_eq!(mapped.unwrap().len(), 1);
    }

    #[test]
    fn map_tools_with_invalid_json_returns_error() {
        let tools = vec![tool_with_parameters("broken_tool", "A tool with bad params", "not valid json{{{")];
        let mapped = map_tools(&tools);
        assert!(mapped.is_err());
        match mapped.unwrap_err() {
            LlmError::ToolParameterParsing { tool_name, .. } => assert_eq!(tool_name, "broken_tool"),
            other => panic!("Expected ToolParameterParsing, got: {other}"),
        }
    }
}