use std::collections::HashMap;
use std::sync::Arc;
use github_copilot_sdk::generated::session_events::{SessionEventType, ToolExecutionCompleteData};
use github_copilot_sdk::handler::ApproveAllHandler;
use github_copilot_sdk::tool::ToolHandler;
use github_copilot_sdk::{
Error, SessionConfig, Tool, ToolInvocation, ToolResult, ToolResultExpanded,
};
use serde_json::json;
use tokio::sync::mpsc;
use super::support::{assistant_message_content, collect_until_idle, with_e2e_context};
/// End-to-end check that an expanded result returned by a custom tool is usable by the model.
///
/// Registers the tool built by `weather_tool()`, asks about the weather in Paris and expects
/// the assistant's reply to mention either "sunny" or "72".
#[tokio::test]
async fn should_handle_structured_toolresultobject_from_custom_tool() {
    with_e2e_context(
        "tool_results",
        "should_handle_structured_toolresultobject_from_custom_tool",
        |ctx| {
            Box::pin(async move {
                ctx.set_default_copilot_user();
                let client = ctx.start_client().await;
                let session = create_tool_session(ctx, &client, weather_tool()).await;

                let sent = session
                    .send_and_wait("What's the weather in Paris?")
                    .await
                    .expect("send");
                let reply = sent.expect("assistant message");

                // Lower-case the reply so the keyword check is case-insensitive.
                let text = assistant_message_content(&reply).to_lowercase();
                let mentions_sunny = text.contains("sunny");
                let mentions_temperature = text.contains("72");
                assert!(mentions_sunny || mentions_temperature);

                session.disconnect().await.expect("disconnect session");
                client.stop().await.expect("stop client");
            })
        },
    )
    .await;
}
/// End-to-end check that a tool result whose result type is "failure" is surfaced to the model.
///
/// Registers the tool built by `check_status_tool()` and instructs the assistant to answer
/// "service is down" when the tool fails; the reply must contain that phrase.
#[tokio::test]
async fn should_handle_tool_result_with_failure_resulttype() {
    with_e2e_context(
        "tool_results",
        "should_handle_tool_result_with_failure_resulttype",
        |ctx| {
            Box::pin(async move {
                ctx.set_default_copilot_user();
                let client = ctx.start_client().await;
                let session = create_tool_session(ctx, &client, check_status_tool()).await;

                let prompt = "Check the status of the service using check_status. If it fails, say 'service is down'.";
                let sent = session.send_and_wait(prompt).await.expect("send");
                let reply = sent.expect("assistant message");

                // Lower-case the reply so the phrase check is case-insensitive.
                let text = assistant_message_content(&reply).to_lowercase();
                assert!(text.contains("service is down"));

                session.disconnect().await.expect("disconnect session");
                client.stop().await.expect("stop client");
            })
        },
    )
    .await;
}
/// End-to-end check that only the tool's text reaches the model, not the serialized result.
///
/// Registers the tool built by `analyze_code_tool()`, expects the assistant to report
/// "no issues", then inspects the last recorded exchange: exactly one message with role
/// "tool" must be present and its serialized form must not contain the `toolTelemetry` or
/// `resultType` keys.
#[tokio::test]
async fn should_preserve_tooltelemetry_and_not_stringify_structured_results_for_llm() {
    with_e2e_context(
        "tool_results",
        "should_preserve_tooltelemetry_and_not_stringify_structured_results_for_llm",
        |ctx| {
            Box::pin(async move {
                ctx.set_default_copilot_user();
                let client = ctx.start_client().await;
                let session = create_tool_session(ctx, &client, analyze_code_tool()).await;

                let sent = session
                    .send_and_wait("Analyze the file main.ts for issues.")
                    .await
                    .expect("send");
                let reply = sent.expect("assistant message");
                let text = assistant_message_content(&reply).to_lowercase();
                assert!(text.contains("no issues"));

                // Pull the message list out of the most recent recorded request.
                let exchanges = ctx.exchanges();
                let request_messages = exchanges
                    .last()
                    .and_then(|exchange| exchange.get("request"))
                    .and_then(|request| request.get("messages"))
                    .and_then(serde_json::Value::as_array)
                    .expect("messages");

                // Keep only the messages that carry a tool result.
                let tool_messages: Vec<_> = request_messages
                    .iter()
                    .filter(|message| {
                        matches!(
                            message.get("role").and_then(serde_json::Value::as_str),
                            Some("tool")
                        )
                    })
                    .collect();
                assert_eq!(tool_messages.len(), 1);

                // The structured wrapper fields must not appear in the tool message.
                let serialized = tool_messages[0].to_string();
                assert!(!serialized.contains("toolTelemetry"));
                assert!(!serialized.contains("resultType"));

                session.disconnect().await.expect("disconnect session");
                client.stop().await.expect("stop client");
            })
        },
    )
    .await;
}
/// End-to-end check that a tool result with result type "rejected" is reported as a failed
/// tool execution.
///
/// Registers the tool built by `deploy_tool(..)`, waits until the tool handler signals that it
/// ran, collects session events until idle and then inspects the first
/// `ToolExecutionComplete` event: it must be unsuccessful, carry error code "rejected" and an
/// error message containing "Deployment rejected".
#[tokio::test]
async fn should_handle_tool_result_with_rejected_resulttype() {
    with_e2e_context(
        "tool_results",
        "should_handle_tool_result_with_rejected_resulttype",
        |ctx| {
            Box::pin(async move {
                ctx.set_default_copilot_user();
                let client = ctx.start_client().await;

                // The tool handler pings this channel each time it is invoked.
                let (call_tx, mut call_rx) = mpsc::unbounded_channel();
                let session = create_tool_session(ctx, &client, deploy_tool(call_tx)).await;

                // Subscribe before sending (presumably so that no event is missed).
                let events = session.subscribe();
                let prompt = "Deploy the service using deploy_service. If it's rejected, tell me it was 'rejected by policy'.";
                session.send(prompt).await.expect("send");

                recv_called(&mut call_rx, "deploy tool").await;
                let observed = collect_until_idle(events).await;

                let completion = observed
                    .iter()
                    .find(|event| event.parsed_type() == SessionEventType::ToolExecutionComplete)
                    .and_then(|event| event.typed_data::<ToolExecutionCompleteData>())
                    .expect("tool.execution_complete");
                assert!(!completion.success);

                let failure = completion.error.expect("tool error");
                assert_eq!(failure.code.as_deref(), Some("rejected"));
                assert!(failure.message.contains("Deployment rejected"));

                session.disconnect().await.expect("disconnect session");
                client.stop().await.expect("stop client");
            })
        },
    )
    .await;
}
/// End-to-end check that a tool result with result type "denied" is reported as a failed
/// tool execution.
///
/// Registers the tool built by `access_secret_tool(..)`, waits until the tool handler signals
/// that it ran, collects session events until idle and then inspects the first
/// `ToolExecutionComplete` event: it must be unsuccessful, carry error code "denied" and an
/// error message containing "Access denied".
#[tokio::test]
async fn should_handle_tool_result_with_denied_resulttype() {
    with_e2e_context(
        "tool_results",
        "should_handle_tool_result_with_denied_resulttype",
        |ctx| {
            Box::pin(async move {
                ctx.set_default_copilot_user();
                let client = ctx.start_client().await;

                // The tool handler pings this channel each time it is invoked.
                let (call_tx, mut call_rx) = mpsc::unbounded_channel();
                let session =
                    create_tool_session(ctx, &client, access_secret_tool(call_tx)).await;

                // Subscribe before sending (presumably so that no event is missed).
                let events = session.subscribe();
                let prompt = "Use access_secret to get the API key. If access is denied, tell me it was 'access denied'.";
                session.send(prompt).await.expect("send");

                recv_called(&mut call_rx, "access secret tool").await;
                let observed = collect_until_idle(events).await;

                let completion = observed
                    .iter()
                    .find(|event| event.parsed_type() == SessionEventType::ToolExecutionComplete)
                    .and_then(|event| event.typed_data::<ToolExecutionCompleteData>())
                    .expect("tool.execution_complete");
                assert!(!completion.success);

                let failure = completion.error.expect("tool error");
                assert_eq!(failure.code.as_deref(), Some("denied"));
                assert!(failure.message.contains("Access denied"));

                session.disconnect().await.expect("disconnect session");
                client.stop().await.expect("stop client");
            })
        },
    )
    .await;
}
/// Creates a session on `client` with `tool` registered as its only custom tool.
///
/// The session is configured with `DEFAULT_TEST_TOKEN` as GitHub token and an
/// `ApproveAllHandler` as permission handler. `_ctx` is accepted but not used.
///
/// # Panics
///
/// Panics with "create session" if the client fails to create the session.
async fn create_tool_session(
    _ctx: &super::support::E2eContext,
    client: &github_copilot_sdk::Client,
    tool: Tool,
) -> github_copilot_sdk::session::Session {
    let approve_all = Arc::new(ApproveAllHandler);
    let config = SessionConfig::default()
        .with_github_token(super::support::DEFAULT_TEST_TOKEN)
        .with_permission_handler(approve_all)
        .with_tools(vec![tool]);
    client.create_session(config).await.expect("create session")
}
/// Waits up to 10 seconds for one signal on `receiver`.
///
/// `description` names the awaited signal in the panic messages.
///
/// # Panics
///
/// Panics if the timeout elapses first, or if the channel is closed without a signal.
async fn recv_called(receiver: &mut mpsc::UnboundedReceiver<()>, description: &'static str) {
    let limit = std::time::Duration::from_secs(10);
    match tokio::time::timeout(limit, receiver.recv()).await {
        Ok(Some(())) => {}
        Ok(None) => panic!("{description} channel closed"),
        Err(_) => panic!("timed out waiting for {description}"),
    }
}
/// Builds a `ToolResult::Expanded` carrying only a text result and a result type.
///
/// All remaining fields (binary results, session log, error, tool telemetry) are `None`.
fn expanded(text: impl Into<String>, result_type: impl Into<String>) -> ToolResult {
    let payload = ToolResultExpanded {
        text_result_for_llm: text.into(),
        result_type: result_type.into(),
        binary_results_for_llm: None,
        session_log: None,
        error: None,
        tool_telemetry: None,
    };
    ToolResult::Expanded(payload)
}
/// Builds the `get_weather` tool: one required string parameter `city`, handled by
/// `WeatherTool`.
fn weather_tool() -> Tool {
    let definition = string_tool(
        "get_weather",
        "Gets weather for a city",
        "city",
        "City name",
    );
    definition.with_handler(Arc::new(WeatherTool))
}
/// Tool handler that reports fixed weather for the requested city.
struct WeatherTool;

#[async_trait::async_trait]
impl ToolHandler for WeatherTool {
    /// Returns a "success" expanded result describing sunny, 72°F weather.
    ///
    /// The city is read from the string argument `city`; "Paris" is used when the argument is
    /// missing or is not a string.
    async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error> {
        let requested = invocation
            .arguments
            .get("city")
            .and_then(serde_json::Value::as_str);
        let city = match requested {
            Some(name) => name,
            None => "Paris",
        };
        let summary = format!("The weather in {city} is sunny and 72\u{b0}F");
        Ok(expanded(summary, "success"))
    }
}
/// Builds the `check_status` tool, which declares no parameters and is handled by
/// `CheckStatusTool`.
fn check_status_tool() -> Tool {
    let definition = Tool::new("check_status").with_description("Checks the status of a service");
    definition.with_handler(Arc::new(CheckStatusTool))
}
struct CheckStatusTool;
#[async_trait::async_trait]
impl ToolHandler for CheckStatusTool {
async fn call(&self, _invocation: ToolInvocation) -> Result<ToolResult, Error> {
let mut result = match expanded("Service unavailable", "failure") {
ToolResult::Expanded(result) => result,
_ => unreachable!(),
};
result.error = Some("API timeout".to_string());
Ok(ToolResult::Expanded(result))
}
}
/// Builds the `analyze_code` tool: one required string parameter `file`, handled by
/// `AnalyzeCodeTool`.
fn analyze_code_tool() -> Tool {
    let definition = string_tool(
        "analyze_code",
        "Analyzes code for issues",
        "file",
        "File to analyze",
    );
    definition.with_handler(Arc::new(AnalyzeCodeTool))
}
struct AnalyzeCodeTool;
#[async_trait::async_trait]
impl ToolHandler for AnalyzeCodeTool {
async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error> {
let file = invocation
.arguments
.get("file")
.and_then(serde_json::Value::as_str)
.unwrap_or("main.ts");
let mut result = match expanded(format!("Analysis of {file}: no issues found"), "success") {
ToolResult::Expanded(result) => result,
_ => unreachable!(),
};
result.tool_telemetry = Some(HashMap::from([(
"metrics".to_string(),
json!({ "analysisTimeMs": 150 }),
)]));
Ok(ToolResult::Expanded(result))
}
}
/// Builds the `deploy_service` tool, handled by a `DeployTool` that signals each invocation
/// on `call_tx`.
fn deploy_tool(call_tx: mpsc::UnboundedSender<()>) -> Tool {
    let handler = Arc::new(DeployTool { call_tx });
    Tool::new("deploy_service")
        .with_description("Deploys a service")
        .with_handler(handler)
}
/// Tool handler that always rejects the deployment.
struct DeployTool {
    /// Receives one unit value per invocation of the handler.
    call_tx: mpsc::UnboundedSender<()>,
}

#[async_trait::async_trait]
impl ToolHandler for DeployTool {
    /// Signals the invocation on `call_tx`, then returns an expanded result with result type
    /// "rejected". The invocation is ignored.
    async fn call(&self, _invocation: ToolInvocation) -> Result<ToolResult, Error> {
        // A failed send is deliberately ignored.
        let _ = self.call_tx.send(());

        let reason =
            "Deployment rejected: policy violation - production deployments require approval";
        Ok(expanded(reason, "rejected"))
    }
}
/// Builds the `access_secret` tool, handled by an `AccessSecretTool` that signals each
/// invocation on `call_tx`.
fn access_secret_tool(call_tx: mpsc::UnboundedSender<()>) -> Tool {
    let handler = Arc::new(AccessSecretTool { call_tx });
    Tool::new("access_secret")
        .with_description("Accesses a secret")
        .with_handler(handler)
}
/// Tool handler that always denies access to the secret.
struct AccessSecretTool {
    /// Receives one unit value per invocation of the handler.
    call_tx: mpsc::UnboundedSender<()>,
}

#[async_trait::async_trait]
impl ToolHandler for AccessSecretTool {
    /// Signals the invocation on `call_tx`, then returns an expanded result with result type
    /// "denied". The invocation is ignored.
    async fn call(&self, _invocation: ToolInvocation) -> Result<ToolResult, Error> {
        // A failed send is deliberately ignored.
        let _ = self.call_tx.send(());

        let reason = "Access denied: insufficient permissions to read secrets";
        Ok(expanded(reason, "denied"))
    }
}
/// Builds a handler-less tool that takes exactly one required string parameter.
///
/// `name` and `description` describe the tool; `parameter` is the name of the single
/// property in the object schema and `parameter_description` is its description.
fn string_tool(
    name: &str,
    description: &str,
    parameter: &str,
    parameter_description: &str,
) -> Tool {
    // JSON schema: an object with one string property, which is also required.
    let schema = json!({
        "type": "object",
        "properties": {
            parameter: {
                "type": "string",
                "description": parameter_description,
            }
        },
        "required": [parameter],
    });
    Tool::new(name)
        .with_description(description)
        .with_parameters(schema)
}