pravah 0.1.1

Typed, stepwise agentic information flows for Rust
Documentation
use async_trait::async_trait;
use gemini_rust::{
    Content, FunctionCall as GeminiFunctionCall, FunctionCallingMode, FunctionDeclaration,
    FunctionResponse as GeminiFunctionResponse, Gemini, GenerationResponse,
    Message as GeminiMessage, Part, Role as GeminiRole, Tool as GeminiTool,
    client::Model as GeminiModel,
};
use serde_json::Value;

use super::super::tools::ToolDefinition;
use super::schema;
use super::{
    Client, ClientError, ClientOptions, ClientOutput, ClientResponse, LlmUrl, Message, Provider,
    Role, TokenUsage, ToolCall, ToolChoice, parse_json_output, validate_tools,
};

/// Builds the `gemini-rust` SDK client for `url.model`.
///
/// The API key is taken from `url.api_key` when present and non-blank, otherwise from the
/// `GEMINI_API_KEY` environment variable. Blank values are treated as missing: an empty key
/// would otherwise be handed straight to the SDK and only fail at request time.
///
/// # Errors
/// Returns [`ClientError::Llm`] if no usable key is found, or if the SDK rejects the
/// key/model combination.
fn build_client(url: &LlmUrl) -> Result<Gemini, ClientError> {
    let api_key = url
        .api_key
        .clone()
        .filter(|key| !key.trim().is_empty())
        .or_else(|| {
            std::env::var("GEMINI_API_KEY")
                .ok()
                .filter(|key| !key.trim().is_empty())
        })
        .ok_or_else(|| ClientError::Llm("GEMINI_API_KEY is not set".into()))?;
    let model = GeminiModel::Custom(url.model.clone());
    Gemini::with_model(&api_key, model).map_err(|e| ClientError::Llm(e.to_string()))
}

/// [`Client`] implementation backed by the `gemini-rust` SDK.
struct GeminiClient {
    /// SDK handle created with the resolved API key and the configured model.
    client: Gemini,
    /// Request settings: preamble, tools, tool choice, output schema and thinking flag.
    options: ClientOptions,
}

/// Translates provider-neutral history into Gemini `contents` messages, preserving order.
///
/// - `System` entries are skipped; the system prompt is sent separately through
///   `with_system_prompt`.
/// - `User` / `Assistant` entries become user / model text messages.
/// - `AssistantToolCalls` entries become a model message of function calls.
/// - A run of consecutive `Tool` entries is folded into a single user message of function
///   responses.
fn build_gemini_messages(history: &[Message]) -> Vec<GeminiMessage> {
    let mut out = Vec::with_capacity(history.len());
    let mut cursor = 0;
    while let Some(entry) = history.get(cursor) {
        // Each arm reports how many history entries it consumed.
        let step = match &entry.role {
            Role::System => 1,
            Role::User => {
                out.push(GeminiMessage::user(entry.content.clone()));
                1
            }
            Role::Assistant => {
                out.push(GeminiMessage::model(&entry.content));
                1
            }
            Role::AssistantToolCalls { calls } => {
                out.push(tool_calls_to_message(calls));
                1
            }
            Role::Tool { .. } => {
                let (grouped, consumed) = tool_responses_to_message(history, cursor);
                out.push(grouped);
                consumed
            }
        };
        cursor += step;
    }
    out
}

/// Builds the Gemini tool declaration set for `tools`.
///
/// Returns `Ok(None)` when there is nothing to declare. Otherwise every definition is
/// converted with [`build_fn_decl`], and the first conversion error is returned.
fn build_tools_spec(tools: &[ToolDefinition]) -> Result<Option<GeminiTool>, ClientError> {
    let declarations = tools
        .iter()
        .map(build_fn_decl)
        .collect::<Result<Vec<FunctionDeclaration>, ClientError>>()?;
    Ok((!declarations.is_empty()).then(|| GeminiTool::with_functions(declarations)))
}

/// Rebuilds the model-role message for an `AssistantToolCalls` history entry.
///
/// Each call becomes a `FunctionCall` part, in order. Only the first stored thought
/// signature of a call, if any, is replayed.
fn tool_calls_to_message(calls: &[ToolCall]) -> GeminiMessage {
    let mut parts = Vec::with_capacity(calls.len());
    for call in calls {
        let thought_signature = call
            .thought_signatures
            .as_deref()
            .and_then(|sigs| sigs.first())
            .cloned();
        parts.push(Part::FunctionCall {
            function_call: GeminiFunctionCall::new(&call.name, call.args.clone()),
            thought_signature,
        });
    }
    let content = Content {
        parts: Some(parts),
        role: Some(GeminiRole::Model),
    };
    GeminiMessage {
        content,
        role: GeminiRole::Model,
    }
}

/// Groups consecutive `Tool` history entries starting at `start` into one user-role message.
///
/// Each entry becomes a `FunctionResponse` part named after the call it answers (see
/// [`resolve_call_name`]). Entry content is parsed as JSON, falling back to a JSON string
/// when it is not valid JSON. The payload is then normalised by
/// [`function_response_payload`], because Gemini only accepts an object there.
///
/// Returns the message and the number of history entries consumed (at least 1 when
/// `history[start]` is a `Tool` entry).
fn tool_responses_to_message(history: &[Message], start: usize) -> (GeminiMessage, usize) {
    let mut parts = Vec::new();
    let mut i = start;
    while i < history.len() {
        let Role::Tool { call_id } = &history[i].role else {
            break;
        };
        // A result can only answer a call made before it, so restrict the lookup to the
        // preceding history. Searching the whole slice could match a later call that
        // reuses the same id.
        let name = resolve_call_name(&history[..i], call_id);
        let content = &history[i].content;
        let val: Value =
            serde_json::from_str(content).unwrap_or_else(|_| Value::String(content.clone()));
        parts.push(Part::FunctionResponse {
            function_response: GeminiFunctionResponse::new(name, function_response_payload(val)),
        });
        i += 1;
    }
    let msg = GeminiMessage {
        content: Content {
            parts: Some(parts),
            role: Some(GeminiRole::User),
        },
        role: GeminiRole::User,
    };
    (msg, i - start)
}

/// Wraps a non-object tool result as `{"output": value}`; objects pass through unchanged.
///
/// Gemini's `functionResponse.response` must be a JSON object. Scalars, arrays and
/// plain-text results are rejected by the API unless wrapped. `output` is the key the
/// Gemini API documents for carrying function output.
fn function_response_payload(value: Value) -> Value {
    if value.is_object() {
        value
    } else {
        serde_json::json!({ "output": value })
    }
}

/// Finds the function name of the tool call whose id is `call_id`.
///
/// `AssistantToolCalls` entries are scanned from the newest to the oldest; within an entry,
/// calls are checked in order. If no call matches, an error is logged and `call_id` itself
/// is returned as the name.
fn resolve_call_name<'a>(history: &'a [Message], call_id: &'a str) -> &'a str {
    let found = history
        .iter()
        .rev()
        .filter_map(|msg| match &msg.role {
            Role::AssistantToolCalls { calls } => Some(calls),
            _ => None,
        })
        .flat_map(|calls| calls.iter())
        .find(|call| call.id == call_id);
    match found {
        Some(call) => &call.name,
        None => {
            tracing::error!(
                call_id,
                "could not resolve tool call name from history; using call_id as fallback"
            );
            call_id
        }
    }
}

/// Converts a [`ToolDefinition`] into a Gemini [`FunctionDeclaration`].
///
/// The parameter schema is passed through `schema::sanitize_strict` first. The declaration
/// is then assembled as JSON and deserialized, so any shape mismatch surfaces as
/// [`ClientError::Serialize`].
fn build_fn_decl(tool: &ToolDefinition) -> Result<FunctionDeclaration, ClientError> {
    let declaration = serde_json::json!({
        "name": tool.name,
        "description": tool.description,
        "parameters": schema::sanitize_strict(tool.parameters.clone()),
    });
    serde_json::from_value::<FunctionDeclaration>(declaration).map_err(ClientError::Serialize)
}

/// Maps the raw API response to a [`ClientOutput`].
fn map_response(
    response: GenerationResponse,
    tools_enabled: bool,
) -> Result<ClientResponse, ClientError> {
    let usage = response.usage_metadata.as_ref().map(|usage| TokenUsage {
        input: usage.prompt_token_count.map(|v| v as u32),
        output: usage.candidates_token_count.map(|v| v as u32),
    });
    let provider_model = response.model_version.clone();
    let raw_metadata = Some(serde_json::json!({
        "response_id": response.response_id.clone(),
    }));
    let fcs = response.function_calls_with_thoughts();
    if !fcs.is_empty() {
        let thought_text = response.text();
        let thought = if thought_text.is_empty() {
            None
        } else {
            Some(thought_text)
        };
        let calls: Vec<ToolCall> = fcs
            .iter()
            .enumerate()
            .map(|(idx, (fc, sig))| ToolCall {
                id: format!("{}_{}", fc.name, idx),
                name: fc.name.clone(),
                args: fc.args.clone(),
                thought_signatures: sig.map(|s| vec![s.to_string()]),
            })
            .collect();
        return Ok(ClientResponse::new(
            Provider::Gemini,
            ClientOutput::ToolCalls { thought, calls },
        )
        .with_usage(usage)
        .with_provider_model(provider_model)
        .with_raw_metadata(raw_metadata));
    }
    if tools_enabled {
        let text = response.text();
        let content = if text.is_empty() { None } else { Some(text) };
        tracing::warn!(model_output = ?content, "LLM response contained no tool calls");
        return Err(ClientError::MissingToolCalls(content));
    }
    let text = response.text();
    if text.is_empty() {
        return Err(ClientError::EmptyResponse);
    }
    Ok(ClientResponse::new(
        Provider::Gemini,
        ClientOutput::Output(parse_json_output(&text)?),
    )
    .with_usage(usage)
    .with_provider_model(provider_model)
    .with_raw_metadata(raw_metadata))
}

impl GeminiClient {
    /// Issues one `generateContent` request for `messages`.
    ///
    /// - Thinking: when `options.thinking` is set, requests dynamic thinking (budget `-1`,
    ///   so the model sizes its own budget). Otherwise requests a budget of `0`.
    /// - Preamble: sent as the system prompt when configured.
    /// - Tools: when `tools_enabled`, attaches the declared tools. Calling mode is `Any`
    ///   for [`ToolChoice::Required`] and `Auto` otherwise.
    /// - Structured output: when tools are disabled and `response_schema` is given, requests
    ///   `application/json` constrained by that schema.
    ///
    /// # Errors
    /// Returns the error from tool-declaration conversion, or [`ClientError::Llm`] for any
    /// SDK/API failure.
    async fn call_api(
        &self,
        messages: Vec<GeminiMessage>,
        tools_enabled: bool,
        response_schema: Option<Value>,
    ) -> Result<GenerationResponse, ClientError> {
        // `-1` is Gemini's documented "dynamic thinking" value. Fixed budgets are
        // range-checked per model (at most a few tens of thousands of tokens), so a huge
        // sentinel such as `i32::MAX` is rejected rather than meaning "unlimited".
        const DYNAMIC_THINKING_BUDGET: i32 = -1;
        let thinking_budget = if self.options.thinking {
            DYNAMIC_THINKING_BUDGET
        } else {
            0
        };
        let mut builder = self
            .client
            .generate_content()
            .with_thinking_budget(thinking_budget);
        if let Some(preamble) = &self.options.preamble {
            builder = builder.with_system_prompt(preamble.clone());
        }
        builder = builder.with_messages(messages);
        if tools_enabled {
            if let Some(tool_spec) = build_tools_spec(&self.options.tools)? {
                // `Any` forces a function call; `Auto` lets the model decide.
                let mode = match self.options.tool_choice {
                    ToolChoice::Required => FunctionCallingMode::Any,
                    _ => FunctionCallingMode::Auto,
                };
                builder = builder
                    .with_tool(tool_spec)
                    .with_function_calling_mode(mode);
            }
        } else if let Some(output_schema) = response_schema {
            builder = builder
                .with_response_mime_type("application/json")
                .with_response_schema(output_schema);
        }
        builder
            .execute()
            .await
            .map_err(|e| ClientError::Llm(e.to_string()))
    }
}

#[async_trait]
impl Client for GeminiClient {
    /// Sends `messages` to Gemini and maps the reply to a [`ClientResponse`].
    ///
    /// Tools are enabled when at least one tool is configured and `tool_choice` is not
    /// [`ToolChoice::Disabled`]. The sanitized output schema is only applied when tools are
    /// disabled.
    ///
    /// # Errors
    /// - [`ClientError::Validation`] when `messages` is empty, ends with assistant tool calls
    ///   that have no results yet, or yields no Gemini contents. `System` entries are not
    ///   forwarded as contents; the system prompt comes from the configured preamble.
    /// - Any error from `validate_tools`, the API call, or response mapping.
    async fn execute(&self, messages: &[Message]) -> Result<ClientResponse, ClientError> {
        let Some(last) = messages.last() else {
            return Err(ClientError::Validation("messages must not be empty".into()));
        };
        if matches!(last.role, Role::AssistantToolCalls { .. }) {
            return Err(ClientError::Validation(
                "history ends with assistant tool calls without tool results".into(),
            ));
        }
        let tools_enabled =
            !self.options.tools.is_empty() && self.options.tool_choice != ToolChoice::Disabled;
        validate_tools(Provider::Gemini, &self.options.tools)?;
        let response_schema = if tools_enabled {
            None
        } else {
            self.options
                .output_schema
                .as_ref()
                .map(|s| schema::sanitize_strict(s.clone()))
        };
        let gemini_messages = build_gemini_messages(messages);
        // A history of only system entries converts to zero contents, which the API rejects
        // with an opaque error; report it up front instead.
        if gemini_messages.is_empty() {
            return Err(ClientError::Validation(
                "history contains no user or model turns (system messages are sent via the preamble)"
                    .into(),
            ));
        }
        let response = self
            .call_api(gemini_messages, tools_enabled, response_schema)
            .await?;
        map_response(response, tools_enabled)
    }
}

/// Creates a boxed Gemini [`Client`] for `url`, configured with `options`.
///
/// # Errors
/// Propagates the [`ClientError`] from resolving the API key or constructing the
/// underlying `gemini-rust` client.
pub fn new_client(url: &LlmUrl, options: ClientOptions) -> Result<Box<dyn Client>, ClientError> {
    build_client(url).map(|client| Box::new(GeminiClient { client, options }) as Box<dyn Client>)
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn make_call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            args: json!({}),
            thought_signatures: None,
        }
    }

    fn tool_calls(calls: Vec<ToolCall>) -> Message {
        Message {
            role: Role::AssistantToolCalls { calls },
            content: String::new(),
            usage: None,
        }
    }

    fn tool_result(call_id: &str, content: &str) -> Message {
        Message {
            role: Role::Tool {
                call_id: call_id.into(),
            },
            content: content.into(),
            usage: None,
        }
    }

    /// With a single user message in history, exactly one message is produced.
    #[test]
    fn build_messages_user_only() {
        let history = vec![Message::user(r#"{"text":"hi"}"#)];
        let msgs = build_gemini_messages(&history);
        assert_eq!(msgs.len(), 1);
    }

    /// System entries are not forwarded as contents (the preamble is sent via the system
    /// prompt instead), so only the user turn produces a message.
    #[test]
    fn build_messages_preamble_is_separate() {
        let history = vec![
            Message {
                role: Role::System,
                content: "be terse".into(),
                usage: None,
            },
            Message::user(r#"{"text":"hi"}"#),
        ];
        let msgs = build_gemini_messages(&history);
        assert_eq!(msgs.len(), 1);
        assert!(!format!("{msgs:?}").contains("be terse"));
    }

    /// History turns appear in their original order.
    #[test]
    fn build_messages_history_in_order() {
        let history = vec![
            Message::user("prev question"),
            Message::assistant("prev answer"),
            Message::user("next question"),
        ];
        let msgs = build_gemini_messages(&history);
        assert_eq!(msgs.len(), 3);
        let debug = format!("{msgs:?}");
        let q = debug.find("prev question").expect("prev question present");
        let a = debug.find("prev answer").expect("prev answer present");
        let n = debug.find("next question").expect("next question present");
        assert!(q < a && a < n, "turns out of order: {debug}");
    }

    /// Tool role messages are grouped into a user-role function-response message.
    #[test]
    fn build_messages_tool_role_included() {
        let history = vec![
            tool_calls(vec![make_call("call-42", "read_file")]),
            tool_result("call-42", r#"{"temp":22}"#),
        ];
        let msgs = build_gemini_messages(&history);
        assert_eq!(msgs.len(), 2);
        let debug = format!("{msgs:?}");
        assert!(debug.contains("read_file"));
    }

    /// Consecutive tool results collapse into a single message and are all consumed.
    #[test]
    fn build_messages_groups_consecutive_tool_results() {
        let history = vec![
            tool_calls(vec![make_call("a", "read_file"), make_call("b", "list_dir")]),
            tool_result("a", r#"{"ok":true}"#),
            tool_result("b", r#"{"ok":true}"#),
        ];
        let msgs = build_gemini_messages(&history);
        assert_eq!(msgs.len(), 2);
        let debug = format!("{msgs:?}");
        assert!(debug.contains("read_file"));
        assert!(debug.contains("list_dir"));

        let (_, consumed) = tool_responses_to_message(&history, 1);
        assert_eq!(consumed, 2);
    }

    /// Full ReAct exchange: user seed + model tool call + tool result = 3 messages.
    #[test]
    fn build_messages_continue_after_tool_result() {
        let history = vec![
            Message::user(r#"{"goal":"ship","known_context":[]}"#),
            tool_calls(vec![make_call("c1", "project_outline")]),
            tool_result("c1", r#"{"files":[]}"#),
        ];
        let msgs = build_gemini_messages(&history);
        assert_eq!(msgs.len(), 3);
    }

    /// Call ids resolve to their function name; unknown ids fall back to the id itself.
    #[test]
    fn resolve_call_name_finds_name_or_falls_back() {
        let history = vec![tool_calls(vec![make_call("c1", "project_outline")])];
        assert_eq!(resolve_call_name(&history, "c1"), "project_outline");
        assert_eq!(resolve_call_name(&history, "missing"), "missing");
    }

    /// No tool definitions means no tool spec.
    #[test]
    fn build_tools_spec_empty_is_none() {
        assert!(matches!(build_tools_spec(&[]), Ok(None)));
    }
}