llm-worker 0.2.1

A library for building autonomous LLM-powered systems
Documentation
//! Gemini リクエスト生成
//!
//! Google Gemini APIへのリクエストボディを構築

use serde::Serialize;
use serde_json::Value;

use crate::llm_client::{
    Request,
    types::{ContentPart, Message, MessageContent, Role, ToolDefinition},
};

use super::GeminiScheme;

/// Top-level request body sent to the Gemini API.
///
/// Keys are serialized in camelCase (`systemInstruction`, `toolConfig`,
/// `generationConfig`). Optional fields that are `None`, as well as an empty
/// `tools` list, are omitted from the JSON entirely.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiRequest {
    /// Conversation history, one entry per converted message.
    pub contents: Vec<GeminiContent>,

    /// System prompt, sent separately from the conversation history.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<GeminiContent>,

    /// Tool groups offered to the model.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<GeminiTool>,

    /// Controls how the model may call the declared tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<GeminiToolConfig>,

    /// Output length and sampling settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GeminiGenerationConfig>,
}

/// A single turn of content: who is speaking plus the parts of the turn.
#[derive(Serialize, Debug)]
pub(crate) struct GeminiContent {
    /// Speaker of this turn; this module produces `"user"` or `"model"`.
    pub role: String,

    /// Ordered parts (text, function calls, function responses).
    pub parts: Vec<GeminiPart>,
}

/// One part of a [`GeminiContent`].
///
/// The enum is `untagged`, so each variant serializes as just its own object:
/// `{"text": ...}`, `{"functionCall": {...}}` or `{"functionResponse": {...}}`.
#[derive(Serialize, Debug)]
#[serde(untagged)]
pub(crate) enum GeminiPart {
    /// Plain text.
    Text { text: String },

    /// A function call made by the model (converted from a tool-use part).
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },

    /// The result of a function call, sent back to the model.
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
}

/// Payload of a `functionCall` part: which function to call and with what arguments.
#[derive(Serialize, Debug)]
pub(crate) struct GeminiFunctionCall {
    /// Name of the function being called.
    pub name: String,

    /// Call arguments as an arbitrary JSON value.
    pub args: Value,
}

/// Payload of a `functionResponse` part.
#[derive(Serialize, Debug)]
pub(crate) struct GeminiFunctionResponse {
    /// Name reported to Gemini for the call being answered.
    pub name: String,

    /// Body of the response returned to the model.
    pub response: GeminiFunctionResponseContent,
}

/// Body of a function response, serialized as `{"name": ..., "content": ...}`.
#[derive(Serialize, Debug)]
pub(crate) struct GeminiFunctionResponseContent {
    /// Name repeated inside the response body (same value as the outer `name`).
    pub name: String,

    /// Tool output; this module stores it as a JSON string value.
    pub content: Value,
}

/// One tool entry in the request, grouping a set of function declarations.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiTool {
    /// Functions the model may call; serialized as `functionDeclarations`.
    pub function_declarations: Vec<GeminiFunctionDeclaration>,
}

/// Declaration of a single function the model may call.
#[derive(Serialize, Debug)]
pub(crate) struct GeminiFunctionDeclaration {
    /// Function name the model uses when it calls this function.
    pub name: String,

    /// Human-readable description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,

    /// Argument schema, taken verbatim from the tool's input schema.
    pub parameters: Value,
}

/// Tool configuration block of the request (`toolConfig`).
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiToolConfig {
    /// Function-calling behaviour; serialized as `functionCallingConfig`.
    pub function_calling_config: GeminiFunctionCallingConfig,
}

/// Function-calling settings (`functionCallingConfig`).
///
/// Both fields are dropped from the JSON when `None`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiFunctionCallingConfig {
    /// Calling mode: one of `AUTO`, `ANY` or `NONE`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<String>,

    /// Whether function-call arguments should be streamed
    /// (`streamFunctionCallArguments`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_function_call_arguments: Option<bool>,
}

/// Generation settings (`generationConfig`).
///
/// Every `None` field and an empty stop-sequence list are omitted from the
/// JSON, so only explicitly configured values reach the API.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiGenerationConfig {
    /// Upper bound on generated tokens (`maxOutputTokens`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,

    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// Nucleus-sampling probability mass (`topP`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    /// Top-k sampling cutoff (`topK`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,

    /// Sequences that stop generation (`stopSequences`).
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
}

impl GeminiScheme {
    /// RequestからGeminiのリクエストボディを構築
    pub(crate) fn build_request(&self, request: &Request) -> GeminiRequest {
        let mut contents = Vec::new();

        for message in &request.messages {
            contents.push(self.convert_message(message));
        }

        // システムプロンプト
        let system_instruction = request.system_prompt.as_ref().map(|s| GeminiContent {
            role: "user".to_string(), // system_instructionではroleは"user"か省略
            parts: vec![GeminiPart::Text { text: s.clone() }],
        });

        // ツール
        let tools = if request.tools.is_empty() {
            vec![]
        } else {
            vec![GeminiTool {
                function_declarations: request.tools.iter().map(|t| self.convert_tool(t)).collect(),
            }]
        };

        // ツール設定
        let tool_config = if !request.tools.is_empty() {
            Some(GeminiToolConfig {
                function_calling_config: GeminiFunctionCallingConfig {
                    mode: Some("AUTO".to_string()),
                    stream_function_call_arguments: if self.stream_function_call_arguments {
                        Some(true)
                    } else {
                        None
                    },
                },
            })
        } else {
            None
        };

        // 生成設定
        let generation_config = Some(GeminiGenerationConfig {
            max_output_tokens: request.config.max_tokens,
            temperature: request.config.temperature,
            top_p: request.config.top_p,
            top_k: request.config.top_k,
            stop_sequences: request.config.stop_sequences.clone(),
        });

        GeminiRequest {
            contents,
            system_instruction,
            tools,
            tool_config,
            generation_config,
        }
    }

    fn convert_message(&self, message: &Message) -> GeminiContent {
        let role = match message.role {
            Role::User => "user",
            Role::Assistant => "model",
        };

        let parts = match &message.content {
            MessageContent::Text(text) => vec![GeminiPart::Text { text: text.clone() }],
            MessageContent::ToolResult {
                tool_use_id,
                content,
            } => {
                // Geminiでは関数レスポンスとしてマップ
                vec![GeminiPart::FunctionResponse {
                    function_response: GeminiFunctionResponse {
                        name: tool_use_id.clone(),
                        response: GeminiFunctionResponseContent {
                            name: tool_use_id.clone(),
                            content: serde_json::Value::String(content.clone()),
                        },
                    },
                }]
            }
            MessageContent::Parts(parts) => parts
                .iter()
                .map(|p| match p {
                    ContentPart::Text { text } => GeminiPart::Text { text: text.clone() },
                    ContentPart::ToolUse { id: _, name, input } => GeminiPart::FunctionCall {
                        function_call: GeminiFunctionCall {
                            name: name.clone(),
                            args: input.clone(),
                        },
                    },
                    ContentPart::ToolResult {
                        tool_use_id,
                        content,
                    } => GeminiPart::FunctionResponse {
                        function_response: GeminiFunctionResponse {
                            name: tool_use_id.clone(),
                            response: GeminiFunctionResponseContent {
                                name: tool_use_id.clone(),
                                content: serde_json::Value::String(content.clone()),
                            },
                        },
                    },
                })
                .collect(),
        };

        GeminiContent {
            role: role.to_string(),
            parts,
        }
    }

    fn convert_tool(&self, tool: &ToolDefinition) -> GeminiFunctionDeclaration {
        GeminiFunctionDeclaration {
            name: tool.name.clone(),
            description: tool.description.clone(),
            parameters: tool.input_schema.clone(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the Gemini request with a default scheme and serializes it, so tests can
    /// check the exact JSON wire format rather than only the Rust structs.
    fn to_json(request: &Request) -> Value {
        serde_json::to_value(GeminiScheme::new().build_request(request))
            .expect("GeminiRequest must serialize to JSON")
    }

    /// Tool definition shared by the tool-related tests.
    fn weather_tool() -> ToolDefinition {
        ToolDefinition::new("get_weather")
            .description("Get current weather")
            .input_schema(serde_json::json!({
                "type": "object",
                "properties": {
                    "location": { "type": "string" }
                },
                "required": ["location"]
            }))
    }

    #[test]
    fn test_build_simple_request() {
        let scheme = GeminiScheme::new();
        let request = Request::new()
            .system("You are a helpful assistant.")
            .user("Hello!");

        let gemini_req = scheme.build_request(&request);

        assert!(gemini_req.system_instruction.is_some());
        assert_eq!(gemini_req.contents.len(), 1);
        assert_eq!(gemini_req.contents[0].role, "user");
    }

    #[test]
    fn test_build_request_with_tool() {
        let scheme = GeminiScheme::new();
        let request = Request::new()
            .user("What's the weather?")
            .tool(weather_tool());

        let gemini_req = scheme.build_request(&request);

        assert_eq!(gemini_req.tools.len(), 1);
        assert_eq!(gemini_req.tools[0].function_declarations.len(), 1);
        assert_eq!(
            gemini_req.tools[0].function_declarations[0].name,
            "get_weather"
        );
        assert!(gemini_req.tool_config.is_some());
    }

    #[test]
    fn test_assistant_role_is_model() {
        let scheme = GeminiScheme::new();
        let request = Request::new().user("Hello").assistant("Hi there!");

        let gemini_req = scheme.build_request(&request);

        assert_eq!(gemini_req.contents.len(), 2);
        assert_eq!(gemini_req.contents[0].role, "user");
        assert_eq!(gemini_req.contents[1].role, "model");
    }

    #[test]
    fn test_serialized_simple_request_uses_camel_case_and_omits_empty_fields() {
        let json = to_json(
            &Request::new()
                .system("You are a helpful assistant.")
                .user("Hello!"),
        );

        assert_eq!(
            json["systemInstruction"]["parts"][0]["text"],
            "You are a helpful assistant."
        );
        assert_eq!(json["contents"][0]["role"], "user");
        assert_eq!(json["contents"][0]["parts"][0]["text"], "Hello!");
        // No tools declared: neither `tools` nor `toolConfig` may appear on the wire.
        assert!(json.get("tools").is_none());
        assert!(json.get("toolConfig").is_none());
    }

    #[test]
    fn test_serialized_tool_request_shape() {
        let json = to_json(
            &Request::new()
                .user("What's the weather?")
                .tool(weather_tool()),
        );

        let declaration = &json["tools"][0]["functionDeclarations"][0];
        assert_eq!(declaration["name"], "get_weather");
        assert_eq!(declaration["description"], "Get current weather");
        assert_eq!(declaration["parameters"]["type"], "object");
        assert_eq!(json["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
    }
}