github-copilot-sdk 1.0.0-beta.7

Rust SDK for programmatic control of the GitHub Copilot CLI via JSON-RPC. Technical preview, pre-1.0.
//! Define custom tools and expose them to the Copilot agent.
//!
//! Registers two tools — `get_weather` (typed params via schemars) and
//! `roll_dice` (manual schema) — then asks the agent a question that
//! triggers tool use.
//!
//! Requires the `derive` feature for typed parameter schemas:
//!
//! ```sh
//! cargo run -p github-copilot-sdk --example tool_server --features derive
//! ```

// Gate the entire example behind the `derive` feature so it compiles
// (as a stub that prints the required feature flag) when clippy/check
// runs without the feature.
#[cfg(not(feature = "derive"))]
fn main() {
    // Fallback entry point for builds without the `derive` feature: tell the
    // user how to run the real example, then exit with a failure status.
    const USAGE: [&str; 2] = [
        "This example requires the `derive` feature:",
        "  cargo run -p github-copilot-sdk --example tool_server --features derive",
    ];

    for line in USAGE {
        eprintln!("{line}");
    }
    std::process::exit(1);
}

#[cfg(feature = "derive")]
use std::sync::Arc;
#[cfg(feature = "derive")]
use std::time::Duration;

#[cfg(feature = "derive")]
use async_trait::async_trait;
#[cfg(feature = "derive")]
use github_copilot_sdk::handler::ApproveAllHandler;
#[cfg(feature = "derive")]
use github_copilot_sdk::tool::{JsonSchema, ToolHandler, schema_for};
#[cfg(feature = "derive")]
use github_copilot_sdk::types::{MessageOptions, SessionConfig, Tool, ToolInvocation, ToolResult};
#[cfg(feature = "derive")]
use github_copilot_sdk::{Client, ClientOptions, Error};
#[cfg(feature = "derive")]
use serde::Deserialize;

// ---------------------------------------------------------------------------
// Tool 1: get_weather — typed parameters derived from a Rust struct
// ---------------------------------------------------------------------------

// Arguments accepted by the `get_weather` tool.
//
// `Deserialize` lets the handler decode the raw invocation arguments into
// this struct, and `JsonSchema` lets `schema_for::<GetWeatherParams>()`
// produce the parameter schema advertised for the tool.
//
// NOTE(review): the `///` comments on the fields presumably end up as the
// property descriptions in the generated schema, so treat them as
// user-facing text rather than internal notes — confirm against the
// schema derive before rewording them.
#[cfg(feature = "derive")]
#[derive(Deserialize, JsonSchema)]
struct GetWeatherParams {
    /// City name (e.g. "Seattle").
    city: String,
    /// Temperature unit: "celsius" or "fahrenheit".
    unit: Option<String>,
}

/// Stateless handler type for the `get_weather` tool; all behavior lives in
/// its `ToolHandler` implementation.
#[cfg(feature = "derive")]
struct GetWeatherTool;

#[cfg(feature = "derive")]
#[async_trait]
impl ToolHandler for GetWeatherTool {
    /// Handles a `get_weather` invocation.
    ///
    /// Decodes the invocation arguments into [`GetWeatherParams`] and returns
    /// a canned, human-readable weather line for the requested city.
    ///
    /// The unit defaults to Celsius; Fahrenheit is selected when `unit` is
    /// `"fahrenheit"` in any letter case. Any other value falls back to
    /// Celsius.
    ///
    /// # Errors
    ///
    /// Returns an error if the arguments cannot be deserialized into
    /// [`GetWeatherParams`] (for example, when `city` is missing).
    async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error> {
        // Stub reading — a real implementation would call a weather API.
        const STUB_CELSIUS: i32 = 18;

        let params: GetWeatherParams = serde_json::from_value(invocation.arguments)?;

        // Models do not always echo the documented lowercase spelling, so
        // compare without regard to ASCII case.
        let wants_fahrenheit = params
            .unit
            .as_deref()
            .is_some_and(|unit| unit.eq_ignore_ascii_case("fahrenheit"));

        // Convert the canned reading so the number and the unit label agree
        // (18 °C is 64 °F, not 18 °F).
        let (degrees, symbol) = if wants_fahrenheit {
            (STUB_CELSIUS * 9 / 5 + 32, "F")
        } else {
            (STUB_CELSIUS, "C")
        };

        Ok(ToolResult::Text(format!(
            "Weather in {}: {degrees}°{symbol}, partly cloudy",
            params.city
        )))
    }
}

// ---------------------------------------------------------------------------
// Tool 2: roll_dice — manual JSON Schema
// ---------------------------------------------------------------------------

/// Stateless handler type for the `roll_dice` tool; all behavior lives in
/// its `ToolHandler` implementation.
#[cfg(feature = "derive")]
struct RollDiceTool;

#[cfg(feature = "derive")]
#[async_trait]
impl ToolHandler for RollDiceTool {
    /// Handles a `roll_dice` invocation.
    ///
    /// Reads two optional integer arguments:
    ///
    /// * `sides` — faces per die; defaults to 6, clamped to `1..=1000`.
    /// * `count` — number of dice; defaults to 1, clamped to `1..=100`.
    ///
    /// Arguments that are missing or are not non-negative integers fall back
    /// to their defaults. The reply lists every individual roll and the sum.
    ///
    /// Rolls come from a small SplitMix64 generator seeded once per call from
    /// the wall clock. That is adequate for an example but is **not**
    /// cryptographically secure.
    async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error> {
        // One SplitMix64 step: advances `state` and returns a well-mixed
        // 64-bit value. Used instead of sampling the clock per roll, because
        // raw clock nanoseconds are often multiples of 100 or 1000 (so
        // `nanos % 20` would always be 0) and back-to-back samples are
        // nearly identical.
        fn next_u64(state: &mut u64) -> u64 {
            *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = *state;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }

        // The clamp bounds the value well inside `u32`, so the cast is lossless.
        let sides = invocation
            .arguments
            .get("sides")
            .and_then(|v| v.as_u64())
            .unwrap_or(6)
            .clamp(1, 1000) as u32;
        let count = invocation
            .arguments
            .get("count")
            .and_then(|v| v.as_u64())
            .unwrap_or(1)
            .clamp(1, 100) as u32;

        // Seed once per call. A clock set before the Unix epoch yields a zero
        // seed, which SplitMix64 handles fine.
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        let mut state = since_epoch
            .as_secs()
            .wrapping_mul(1_000_000_000)
            .wrapping_add(u64::from(since_epoch.subsec_nanos()));

        // The remainder is below `sides` (at most 1000), so narrowing to `u32`
        // cannot truncate. Modulo bias is negligible against a 64-bit source.
        let rolls: Vec<u32> = (0..count)
            .map(|_| (next_u64(&mut state) % u64::from(sides)) as u32 + 1)
            .collect();
        // At most 100 dice × 1000 sides, so the sum cannot overflow `u32`.
        let total: u32 = rolls.iter().sum();

        Ok(ToolResult::Text(format!(
            "Rolled {count}d{sides}: {rolls:?} = {total}"
        )))
    }
}

// ---------------------------------------------------------------------------
// Main
// ---------------------------------------------------------------------------

#[cfg(feature = "derive")]
#[tokio::main]
async fn main() -> Result<(), github_copilot_sdk::Error> {
    // A single prompt that needs both registered tools to answer.
    const PROMPT: &str = "What's the weather in Seattle? Also roll 3d20 for me.";
    // How long to wait for the agent's reply to the prompt.
    const REPLY_TIMEOUT: Duration = Duration::from_secs(60);

    let client = Client::start(ClientOptions::default()).await?;

    // Approve every permission request so the example runs unattended.
    let base_config =
        SessionConfig::default().with_permission_handler(Arc::new(ApproveAllHandler));

    // Tool 1: parameter schema generated from the `GetWeatherParams` struct.
    let weather_tool = Tool::new("get_weather")
        .with_description("Get the current weather for a city.")
        .with_parameters(schema_for::<GetWeatherParams>())
        .with_handler(Arc::new(GetWeatherTool));

    // Tool 2: parameter schema written by hand as a JSON value.
    let dice_tool = Tool::new("roll_dice")
        .with_description("Roll one or more dice and return the total.")
        .with_parameters(serde_json::json!({
            "type": "object",
            "properties": {
                "sides": { "type": "integer", "description": "Number of sides per die (default 6, max 1000)." },
                "count": { "type": "integer", "description": "Number of dice to roll (default 1, max 100)." }
            }
        }))
        .with_handler(Arc::new(RollDiceTool));

    let session = client
        .create_session(base_config.with_tools(vec![weather_tool, dice_tool]))
        .await?;

    println!(
        "Session {} — asking about weather + dice...\n",
        session.id()
    );

    let request = MessageOptions::new(PROMPT).with_wait_timeout(REPLY_TIMEOUT);
    let response = session.send_and_wait(request).await?;

    // Print the reply's `content` string, or a placeholder when the returned
    // event has no such string field. Nothing is printed when no event came back.
    if let Some(event) = response {
        let content = event.data.get("content").and_then(|value| value.as_str());
        println!("{}", content.unwrap_or("<no response>"));
    }

    session.disconnect().await?;
    Ok(())
}