instructors 1.3.3

Type-safe structured output extraction from LLMs. The Rust instructor.
Documentation
use schemars::Schema;
use serde::{Deserialize, Serialize};

use super::{ImageInput, Message, RawResponse, StreamCallback};
use crate::error::{Error, Result};
use crate::schema;

/// JSON body posted to the Gemini `generateContent` /
/// `streamGenerateContent` endpoints. Multi-word keys are emitted in
/// camelCase.
#[derive(Serialize)]
struct Request {
    /// Conversation turns, in the order they are sent.
    contents: Vec<GemContent>,
    /// System prompt; the key is left out of the JSON entirely when `None`.
    #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GemContent>,
    /// Sampling settings plus the JSON-output constraints.
    #[serde(rename = "generationConfig")]
    generation_config: GenerationConfig,
}

/// A single message in the request: a role label plus the parts that make up
/// its body. Used both for conversation turns and for the system instruction.
#[derive(Serialize)]
struct GemContent {
    /// Role string sent to the API (e.g. `"user"` or `"model"`).
    role: String,
    /// Image and text parts of this message, in serialization order.
    parts: Vec<GemPart>,
}

/// One piece of a message body.
///
/// Serialized untagged, so each variant becomes a JSON object holding only
/// its own key: `text`, `inlineData`, or `fileData`.
#[derive(Serialize)]
#[serde(untagged)]
enum GemPart {
    /// Plain text.
    Text { text: String },
    /// Image bytes embedded directly in the request.
    #[serde(rename_all = "camelCase")]
    InlineData { inline_data: InlineData },
    /// Image referenced by URI rather than embedded.
    #[serde(rename_all = "camelCase")]
    FileData { file_data: FileData },
}

/// Payload of an `inlineData` part: an embedded image and its media type.
#[derive(Serialize)]
struct InlineData {
    /// Media type of `data`, serialized as `mimeType`.
    #[serde(rename = "mimeType")]
    mime_type: String,
    /// Image payload, copied verbatim from `ImageInput::Base64`.
    data: String,
}

/// Payload of a `fileData` part: an image referenced by location.
#[derive(Serialize)]
struct FileData {
    /// Location of the image, serialized as `fileUri`.
    #[serde(rename = "fileUri")]
    file_uri: String,
}

/// The `generationConfig` object of a request: sampling limits and the
/// structured-output settings.
#[derive(Serialize)]
struct GenerationConfig {
    /// Sampling temperature; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Cap on generated tokens; omitted from the JSON when `None`.
    #[serde(rename = "maxOutputTokens", skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    /// MIME type requested for the model output.
    #[serde(rename = "responseMimeType")]
    response_mime_type: String,
    /// JSON schema the model output must conform to.
    #[serde(rename = "responseSchema")]
    response_schema: serde_json::Value,
}

/// The subset of a Gemini response body (or of one streamed event) that this
/// module reads. Both fields may be absent.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct Response {
    /// Generated candidates; only the first one is consumed by the readers.
    candidates: Option<Vec<Candidate>>,
    /// Token accounting, read from the `usageMetadata` key.
    usage_metadata: Option<UsageMetadata>,
}

/// One generated candidate inside a [`Response`].
#[derive(Deserialize)]
struct Candidate {
    /// Generated content; optional because the key may be missing.
    content: Option<CandidateContent>,
}

/// Content of a [`Candidate`]: the list of parts the model produced.
#[derive(Deserialize)]
struct CandidateContent {
    /// Output parts; optional because the key may be missing.
    parts: Option<Vec<ResponsePart>>,
}

/// A single output part; only its text is of interest here.
#[derive(Deserialize)]
struct ResponsePart {
    /// Text carried by the part, or `None` when the part has no `text` key.
    text: Option<String>,
}

/// Token counts reported by the API. A count missing from the JSON
/// deserializes as `0`.
#[derive(Deserialize)]
struct UsageMetadata {
    /// Tokens consumed by the prompt (`promptTokenCount`).
    #[serde(default, rename = "promptTokenCount")]
    prompt_token_count: u32,
    /// Tokens produced across candidates (`candidatesTokenCount`).
    #[serde(default, rename = "candidatesTokenCount")]
    candidates_token_count: u32,
}

/// Converts a [`Message`] into Gemini parts: one part per attached image
/// (in their original order) followed by a single text part holding the
/// message content.
///
/// Base64 images become `inlineData` parts; URL images become `fileData`
/// parts. The text part is always emitted, even when the content is empty.
fn build_parts(msg: &Message) -> Vec<GemPart> {
    let image_parts = msg.images.iter().map(|image| match image {
        ImageInput::Base64 { media_type, data } => GemPart::InlineData {
            inline_data: InlineData {
                mime_type: media_type.clone(),
                data: data.clone(),
            },
        },
        ImageInput::Url(url) => GemPart::FileData {
            file_data: FileData {
                file_uri: url.clone(),
            },
        },
    });

    let text_part = GemPart::Text {
        text: msg.content.clone(),
    };

    // Images first, text last — same ordering the request has always used.
    image_parts.chain(std::iter::once(text_part)).collect()
}

/// Translates a chat role name into the one sent to Gemini: `"assistant"`
/// becomes `"model"`; any other role is passed through unchanged.
fn gemini_role(role: &str) -> &str {
    if role == "assistant" {
        "model"
    } else {
        role
    }
}

/// Sends one structured-output request to Gemini and returns the raw reply.
///
/// The request body is built from `messages` (roles mapped through
/// `gemini_role`, images and text through `build_parts`) and asks for
/// `application/json` output constrained by `schema`, after it has been
/// passed through `schema::clean_for_gemini`. When `system` is `None`, a
/// default extraction prompt is used as the system instruction.
///
/// If `on_stream` holds a callback, the `streamGenerateContent` endpoint is
/// called with `alt=sse` and the callback receives each text fragment as it
/// arrives; otherwise `generateContent` is called and the single response is
/// parsed.
///
/// The API key is sent in the `x-goog-api-key` header instead of the `key`
/// query parameter, so the secret is not part of the request URL (URLs can
/// surface in error messages and logs).
///
/// # Errors
///
/// * Transport failures from the HTTP client are propagated via `?`.
/// * A non-success HTTP status yields `Error::Api` carrying the status code
///   and the response body text (empty if the body could not be read).
/// * Any error produced while reading the response or the stream.
// The argument list mirrors the per-request settings passed in by the
// caller; bundling them into a struct would change the crate-internal
// interface.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn send_gemini(
    http: &reqwest::Client,
    base_url: &str,
    api_key: &str,
    model: &str,
    system: Option<&str>,
    messages: &[Message],
    schema: &Schema,
    temperature: Option<f64>,
    max_tokens: u32,
    on_stream: StreamCallback<'_>,
) -> Result<RawResponse> {
    let contents: Vec<GemContent> = messages
        .iter()
        .map(|m| GemContent {
            role: gemini_role(&m.role).into(),
            parts: build_parts(m),
        })
        .collect();

    let sys = system
        .unwrap_or("Extract the requested information from the given text. Return valid JSON matching the provided schema.");

    let system_instruction = GemContent {
        role: "user".into(),
        parts: vec![GemPart::Text { text: sys.into() }],
    };

    let response_schema = schema::clean_for_gemini(schema);

    let body = Request {
        contents,
        system_instruction: Some(system_instruction),
        generation_config: GenerationConfig {
            temperature,
            max_output_tokens: Some(max_tokens),
            response_mime_type: "application/json".into(),
            response_schema,
        },
    };

    // `alt=sse` makes the streaming endpoint answer with server-sent events,
    // which is the framing `read_stream` parses.
    let url = if on_stream.is_some() {
        format!("{base_url}/models/{model}:streamGenerateContent?alt=sse")
    } else {
        format!("{base_url}/models/{model}:generateContent")
    };

    let resp = http
        .post(&url)
        // Keep the credential out of the URL; see the function docs.
        .header("x-goog-api-key", api_key)
        .header("content-type", "application/json")
        .json(&body)
        .send()
        .await?;

    let status = resp.status();
    if !status.is_success() {
        // Best effort: an unreadable error body degrades to an empty message.
        let text = resp.text().await.unwrap_or_default();
        return Err(Error::Api {
            status: status.as_u16(),
            message: text,
        });
    }

    match on_stream {
        Some(callback) => read_stream(resp, callback).await,
        None => read_response(resp).await,
    }
}

async fn read_response(resp: reqwest::Response) -> Result<RawResponse> {
    let data: Response = resp.json().await?;
    let usage = data.usage_metadata.unwrap_or(UsageMetadata {
        prompt_token_count: 0,
        candidates_token_count: 0,
    });

    let text = data
        .candidates
        .and_then(|c| c.into_iter().next())
        .and_then(|c| c.content)
        .and_then(|c| c.parts)
        .and_then(|p| p.into_iter().next())
        .and_then(|p| p.text)
        .ok_or_else(|| Error::Other("no content in gemini response".into()))?;

    Ok(RawResponse {
        content: text,
        input_tokens: usage.prompt_token_count,
        output_tokens: usage.candidates_token_count,
    })
}

/// Consumes a server-sent-events response from Gemini, forwarding each text
/// fragment to `callback` as it arrives and returning the concatenated text
/// together with the last reported token counts.
///
/// Raw bytes are buffered and only decoded once a complete line is
/// available, so a multi-byte UTF-8 character that straddles two network
/// chunks is not corrupted. A final line that is not newline-terminated is
/// still processed when the stream ends.
///
/// # Errors
///
/// Returns the transport error if reading a chunk from the body fails.
/// Events whose payload is not valid JSON are skipped rather than reported.
async fn read_stream(
    resp: reqwest::Response,
    callback: &(dyn Fn(&str) + Send + Sync),
) -> Result<RawResponse> {
    use futures::StreamExt;

    let mut accumulated = String::new();
    let mut input_tokens = 0u32;
    let mut output_tokens = 0u32;
    let mut stream = resp.bytes_stream();
    let mut buffer: Vec<u8> = Vec::new();

    while let Some(chunk) = stream.next().await {
        buffer.extend_from_slice(&chunk?);

        // Handle every complete line currently buffered, then discard the
        // consumed prefix in a single drain.
        let mut start = 0;
        while let Some(offset) = buffer[start..].iter().position(|&b| b == b'\n') {
            let end = start + offset;
            apply_sse_line(
                &buffer[start..end],
                callback,
                &mut accumulated,
                &mut input_tokens,
                &mut output_tokens,
            );
            start = end + 1;
        }
        buffer.drain(..start);
    }

    // The stream may end without a trailing newline; do not lose that line.
    if !buffer.is_empty() {
        apply_sse_line(
            &buffer,
            callback,
            &mut accumulated,
            &mut input_tokens,
            &mut output_tokens,
        );
    }

    Ok(RawResponse {
        content: accumulated,
        input_tokens,
        output_tokens,
    })
}

/// Applies one SSE line (without its `\n` terminator) to the running state.
///
/// Only `data:` lines are acted on; blank lines, comment lines starting with
/// `:`, and any other field are ignored. The single optional space after
/// `data:` is stripped. Every text part of the event's first candidate is
/// passed to `callback` and appended to `accumulated`; token counts are
/// overwritten whenever the event carries usage metadata.
fn apply_sse_line(
    raw: &[u8],
    callback: &(dyn Fn(&str) + Send + Sync),
    accumulated: &mut String,
    input_tokens: &mut u32,
    output_tokens: &mut u32,
) {
    let line = String::from_utf8_lossy(raw);
    let line = line.trim_end_matches('\r');

    let Some(data) = line.strip_prefix("data:") else {
        return;
    };
    let data = data.strip_prefix(' ').unwrap_or(data);

    // Best effort: a payload that is not a JSON event is skipped.
    let Ok(event) = serde_json::from_str::<Response>(data) else {
        return;
    };

    if let Some(usage) = event.usage_metadata {
        *input_tokens = usage.prompt_token_count;
        *output_tokens = usage.candidates_token_count;
    }

    let parts = event
        .candidates
        .and_then(|candidates| candidates.into_iter().next())
        .and_then(|candidate| candidate.content)
        .and_then(|content| content.parts)
        .unwrap_or_default();

    for text in parts.into_iter().filter_map(|part| part.text) {
        callback(&text);
        accumulated.push_str(&text);
    }
}