Skip to main content

github_copilot_sdk/
transforms.rs

1//! System message transform callbacks for customizing agent prompts.
2//!
3//! Implement [`SystemMessageTransform`](crate::transforms::SystemMessageTransform) to intercept and modify system prompt
4//! sections during session creation. The CLI sends the current content for
5//! each section the transform registered, and the SDK returns the modified
6//! content.
7
8use std::collections::HashMap;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12
13use crate::types::SessionId;
14
15/// Context provided to every transform invocation.
16#[derive(Debug, Clone)]
17pub struct TransformContext {
18    /// The session being created or resumed.
19    pub session_id: SessionId,
20}
21
22/// Handles `systemMessage.transform` RPC requests from the CLI.
23///
24/// The CLI sends these during session creation/resumption when the session's
25/// `SystemMessageConfig` contains sections with `action: "transform"`. For each
26/// such section, the CLI provides the current content and expects the SDK to
27/// return the (possibly modified) content.
28///
29/// Implement this trait and pass it to [`Client::create_session`](crate::Client::create_session) /
30/// [`Client::resume_session`](crate::Client::resume_session) to participate in system message customization.
31///
32/// # Example
33///
34/// ```ignore
35/// struct MyTransform;
36///
37/// #[async_trait::async_trait]
38/// impl SystemMessageTransform for MyTransform {
39///     fn section_ids(&self) -> Vec<String> {
40///         vec!["instructions".to_string()]
41///     }
42///
43///     async fn transform_section(
44///         &self,
45///         _section_id: &str,
46///         content: &str,
47///         _ctx: TransformContext,
48///     ) -> Option<String> {
49///         Some(format!("{content}\n\nAlways be concise."))
50///     }
51/// }
52/// ```
53#[async_trait]
54pub trait SystemMessageTransform: Send + Sync + 'static {
55    /// Section IDs this transform handles.
56    ///
57    /// The SDK injects `action: "transform"` entries into the
58    /// [`SystemMessageConfig`](crate::types::SystemMessageConfig) wire format
59    /// for each returned ID.
60    fn section_ids(&self) -> Vec<String>;
61
62    /// Transform a section's content. Return `Some(new_content)` to modify the
63    /// section, or `None` to pass through unchanged.
64    async fn transform_section(
65        &self,
66        section_id: &str,
67        content: &str,
68        ctx: TransformContext,
69    ) -> Option<String>;
70}
71
72/// Wire format for a single section in the transform request/response.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub(crate) struct TransformSection {
75    pub(crate) content: String,
76}
77
78/// Wire format for the `systemMessage.transform` response.
79#[derive(Debug, Clone, Serialize)]
80pub(crate) struct TransformResponse {
81    pub(crate) sections: HashMap<String, TransformSection>,
82}
83
84/// Apply transforms to the incoming sections map, returning the response.
85///
86/// For each section, calls the matching transform if the implementor returns
87/// `Some`; otherwise passes through the original content.
88pub(crate) async fn dispatch_transform(
89    transform: &dyn SystemMessageTransform,
90    session_id: &SessionId,
91    sections: HashMap<String, TransformSection>,
92) -> TransformResponse {
93    let ctx = TransformContext {
94        session_id: session_id.clone(),
95    };
96
97    let mut result = HashMap::with_capacity(sections.len());
98    for (section_id, data) in sections {
99        let content = match transform
100            .transform_section(&section_id, &data.content, ctx.clone())
101            .await
102        {
103            Some(transformed) => transformed,
104            None => data.content,
105        };
106        result.insert(section_id, TransformSection { content });
107    }
108
109    TransformResponse { sections: result }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Transform that rewrites only the `instructions` section; every other
    /// section ID yields `None`, i.e. pass-through.
    struct TestTransform;

    #[async_trait]
    impl SystemMessageTransform for TestTransform {
        fn section_ids(&self) -> Vec<String> {
            vec!["instructions".to_string(), "context".to_string()]
        }

        async fn transform_section(
            &self,
            section_id: &str,
            content: &str,
            _ctx: TransformContext,
        ) -> Option<String> {
            match section_id {
                "instructions" => Some(format!("[modified] {content}")),
                _ => None,
            }
        }
    }

    /// Build an incoming sections map from `(section_id, content)` pairs.
    fn sections_from(pairs: &[(&str, &str)]) -> HashMap<String, TransformSection> {
        pairs
            .iter()
            .map(|&(id, content)| {
                (
                    id.to_string(),
                    TransformSection {
                        content: content.to_string(),
                    },
                )
            })
            .collect()
    }

    /// Dispatch `pairs` through [`TestTransform`] using a fixed session ID.
    async fn dispatch(pairs: &[(&str, &str)]) -> TransformResponse {
        dispatch_transform(&TestTransform, &SessionId::new("sess-1"), sections_from(pairs)).await
    }

    #[tokio::test]
    async fn dispatch_applies_matching_transform() {
        let response = dispatch(&[("instructions", "be helpful")]).await;

        // Exactly the incoming section comes back, with the rewrite applied.
        assert_eq!(response.sections.len(), 1);
        assert_eq!(
            response.sections["instructions"].content,
            "[modified] be helpful"
        );
    }

    #[tokio::test]
    async fn dispatch_passes_through_unhandled_section() {
        // `context` is registered in `section_ids` but the transform returns `None`.
        let response = dispatch(&[("context", "original context")]).await;

        assert_eq!(response.sections.len(), 1);
        assert_eq!(response.sections["context"].content, "original context");
    }

    #[tokio::test]
    async fn dispatch_unknown_section_passes_through() {
        // `unknown` is not registered at all; it must still be echoed back.
        let response = dispatch(&[("unknown", "mystery")]).await;

        assert_eq!(response.sections.len(), 1);
        assert_eq!(response.sections["unknown"].content, "mystery");
    }

    #[tokio::test]
    async fn dispatch_mixed_sections() {
        let response = dispatch(&[
            ("instructions", "help me"),
            ("context", "some context"),
            ("other", "other stuff"),
        ])
        .await;

        // No section is dropped or invented.
        assert_eq!(response.sections.len(), 3);
        assert_eq!(
            response.sections["instructions"].content,
            "[modified] help me"
        );
        assert_eq!(response.sections["context"].content, "some context");
        assert_eq!(response.sections["other"].content, "other stuff");
    }

    // Synchronous: `section_ids` is not async, so no runtime is needed.
    #[test]
    fn section_ids_returns_registered_sections() {
        let ids = TestTransform.section_ids();
        assert_eq!(ids, vec!["instructions", "context"]);
    }
}