// github-copilot-sdk 1.0.0-beta.4
//
// Rust SDK for programmatic control of the GitHub Copilot CLI via JSON-RPC.
// Technical preview, pre-1.0.
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use github_copilot_sdk::transforms::{SystemMessageTransform, TransformContext};
use github_copilot_sdk::{SectionOverride, SessionConfig, SystemMessageConfig};
use tokio::sync::mpsc;

use super::support::{DEFAULT_TEST_TOKEN, get_system_message, recv_with_timeout, with_e2e_context};

/// A transform registered for `identity` and `tone` must be called back once
/// for each of those sections, and each callback must carry non-empty content.
#[tokio::test]
async fn should_invoke_transform_callbacks_with_section_content() {
    with_e2e_context(
        "system_message_transform",
        "should_invoke_transform_callbacks_with_section_content",
        |ctx| {
            Box::pin(async move {
                ctx.set_default_copilot_user();
                let fixture_path = ctx.work_dir().join("test.txt");
                std::fs::write(fixture_path, "Hello transform!").expect("write test file");

                let (tx, mut rx) = mpsc::unbounded_channel();
                let recorder = RecordingTransform {
                    section_ids: vec!["identity", "tone"],
                    suffix: None,
                    section_tx: tx,
                };

                let client = ctx.start_client().await;
                let config = SessionConfig::default()
                    .with_github_token(DEFAULT_TEST_TOKEN)
                    .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler))
                    .with_transform(Arc::new(recorder));
                let session = client.create_session(config).await.expect("create session");

                session
                    .send_and_wait("Read the contents of test.txt and tell me what it says")
                    .await
                    .expect("send");

                // Each callback reports `(section_id, content_len)` over the channel.
                let (first_id, first_len) = recv_with_timeout(&mut rx, "first transform").await;
                let (second_id, second_len) =
                    recv_with_timeout(&mut rx, "second transform").await;
                assert!(first_len > 0);
                assert!(second_len > 0);

                // Callback order is not pinned; only membership is checked.
                let seen = [first_id.as_str(), second_id.as_str()];
                assert!(seen.contains(&"identity"));
                assert!(seen.contains(&"tone"));

                session.disconnect().await.expect("disconnect session");
                client.stop().await.expect("stop client");
            })
        },
    )
    .await;
}

/// Text appended by a transform to the `identity` section must show up in the
/// system message of the first recorded exchange.
#[tokio::test]
async fn should_apply_transform_modifications_to_section_content() {
    with_e2e_context(
        "system_message_transform",
        "should_apply_transform_modifications_to_section_content",
        |ctx| {
            Box::pin(async move {
                ctx.set_default_copilot_user();
                let hello_path = ctx.work_dir().join("hello.txt");
                std::fs::write(hello_path, "Hello!").expect("write hello file");

                // The receiver is never read, but the named binding keeps it
                // alive for the whole test body.
                let (tx, _rx) = mpsc::unbounded_channel();
                let marker_transform = RecordingTransform {
                    section_ids: vec!["identity"],
                    suffix: Some("\nAlways end your reply with TRANSFORM_MARKER"),
                    section_tx: tx,
                };

                let client = ctx.start_client().await;
                let config = SessionConfig::default()
                    .with_github_token(DEFAULT_TEST_TOKEN)
                    .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler))
                    .with_transform(Arc::new(marker_transform));
                let session = client.create_session(config).await.expect("create session");

                session
                    .send_and_wait("Read the contents of hello.txt")
                    .await
                    .expect("send");

                let exchanges = ctx.exchanges();
                assert!(!exchanges.is_empty());
                let system_message = get_system_message(&exchanges[0]);
                assert!(system_message.contains("TRANSFORM_MARKER"));

                session.disconnect().await.expect("disconnect session");
                client.stop().await.expect("stop client");
            })
        },
    )
    .await;
}

/// A static `customize` override that removes `safety` and a dynamic
/// transform on `identity` can be configured on the same session, and the
/// transform is still invoked with non-empty `identity` content.
#[tokio::test]
async fn should_work_with_static_overrides_and_transforms_together() {
    with_e2e_context(
        "system_message_transform",
        "should_work_with_static_overrides_and_transforms_together",
        |ctx| {
            Box::pin(async move {
                ctx.set_default_copilot_user();
                let combo_path = ctx.work_dir().join("combo.txt");
                std::fs::write(combo_path, "Combo test!").expect("write combo file");

                let (tx, mut rx) = mpsc::unbounded_channel();
                // Static override: drop the `safety` section entirely.
                let overrides = HashMap::from([(
                    "safety".to_string(),
                    SectionOverride {
                        action: Some("remove".to_string()),
                        content: None,
                    },
                )]);
                let identity_recorder = RecordingTransform {
                    section_ids: vec!["identity"],
                    suffix: None,
                    section_tx: tx,
                };

                let client = ctx.start_client().await;
                let system_message = SystemMessageConfig::new()
                    .with_mode("customize")
                    .with_sections(overrides);
                let config = SessionConfig::default()
                    .with_github_token(DEFAULT_TEST_TOKEN)
                    .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler))
                    .with_system_message(system_message)
                    .with_transform(Arc::new(identity_recorder));
                let session = client.create_session(config).await.expect("create session");

                session
                    .send_and_wait("Read the contents of combo.txt and tell me what it says")
                    .await
                    .expect("send");

                let (section, content_len) = recv_with_timeout(&mut rx, "identity transform").await;
                assert_eq!(section, "identity");
                assert!(content_len > 0);

                session.disconnect().await.expect("disconnect session");
                client.stop().await.expect("stop client");
            })
        },
    )
    .await;
}

/// Test double for [`SystemMessageTransform`]. It reports every section it is
/// asked to transform over a channel and optionally appends a fixed suffix to
/// that section's content.
struct RecordingTransform {
    /// Sections this transform registers for; returned (as owned strings)
    /// from `section_ids`.
    section_ids: Vec<&'static str>,
    /// When `Some`, appended verbatim to each transformed section's content.
    /// When `None`, content is passed through unchanged.
    suffix: Option<&'static str>,
    /// Receives `(section_id, content.len())` for each `transform_section` call.
    section_tx: mpsc::UnboundedSender<(String, usize)>,
}

#[async_trait]
impl SystemMessageTransform for RecordingTransform {
    /// Returns the configured section ids, converted to owned `String`s.
    fn section_ids(&self) -> Vec<String> {
        self.section_ids.iter().copied().map(String::from).collect()
    }

    /// Records `(section_id, content length)` on the channel, then returns the
    /// content with the configured suffix appended (or unchanged if there is
    /// no suffix).
    async fn transform_section(
        &self,
        section_id: &str,
        content: &str,
        _ctx: TransformContext,
    ) -> Option<String> {
        // Best-effort: a closed receiver must not make the transform fail.
        let report = (section_id.to_owned(), content.len());
        let _ = self.section_tx.send(report);

        let transformed = self
            .suffix
            .map_or_else(|| content.to_owned(), |suffix| format!("{content}{suffix}"));
        Some(transformed)
    }
}