github_copilot_sdk/
transforms.rs1use std::collections::HashMap;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12
13use crate::types::SessionId;
14
15#[derive(Debug, Clone)]
17pub struct TransformContext {
18 pub session_id: SessionId,
20}
21
22#[async_trait]
54pub trait SystemMessageTransform: Send + Sync + 'static {
55 fn section_ids(&self) -> Vec<String>;
61
62 async fn transform_section(
65 &self,
66 section_id: &str,
67 content: &str,
68 ctx: TransformContext,
69 ) -> Option<String>;
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub(crate) struct TransformSection {
75 pub(crate) content: String,
76}
77
78#[derive(Debug, Clone, Serialize)]
80pub(crate) struct TransformResponse {
81 pub(crate) sections: HashMap<String, TransformSection>,
82}
83
84pub(crate) async fn dispatch_transform(
89 transform: &dyn SystemMessageTransform,
90 session_id: &SessionId,
91 sections: HashMap<String, TransformSection>,
92) -> TransformResponse {
93 let ctx = TransformContext {
94 session_id: session_id.clone(),
95 };
96
97 let mut result = HashMap::with_capacity(sections.len());
98 for (section_id, data) in sections {
99 let content = match transform
100 .transform_section(§ion_id, &data.content, ctx.clone())
101 .await
102 {
103 Some(transformed) => transformed,
104 None => data.content,
105 };
106 result.insert(section_id, TransformSection { content });
107 }
108
109 TransformResponse { sections: result }
110}
111
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Rewrites only the "instructions" section; passes everything else through.
    struct TestTransform;

    #[async_trait]
    impl SystemMessageTransform for TestTransform {
        fn section_ids(&self) -> Vec<String> {
            vec!["instructions".to_string(), "context".to_string()]
        }

        async fn transform_section(
            &self,
            section_id: &str,
            content: &str,
            _ctx: TransformContext,
        ) -> Option<String> {
            match section_id {
                "instructions" => Some(format!("[modified] {content}")),
                _ => None,
            }
        }
    }

    /// Records every `(section_id, debug(session_id))` it is called with and
    /// never rewrites anything.
    struct RecordingTransform {
        calls: Mutex<Vec<(String, String)>>,
    }

    #[async_trait]
    impl SystemMessageTransform for RecordingTransform {
        fn section_ids(&self) -> Vec<String> {
            Vec::new()
        }

        async fn transform_section(
            &self,
            section_id: &str,
            _content: &str,
            ctx: TransformContext,
        ) -> Option<String> {
            // The guard is a statement temporary, dropped before returning;
            // nothing is held across an await point.
            self.calls
                .lock()
                .unwrap()
                .push((section_id.to_string(), format!("{:?}", ctx.session_id)));
            None
        }
    }

    /// Builds a section map from `(id, content)` pairs.
    fn sections_from(pairs: &[(&str, &str)]) -> HashMap<String, TransformSection> {
        pairs
            .iter()
            .map(|&(id, content)| {
                (
                    id.to_string(),
                    TransformSection {
                        content: content.to_string(),
                    },
                )
            })
            .collect()
    }

    #[tokio::test]
    async fn dispatch_applies_matching_transform() {
        let sections = sections_from(&[("instructions", "be helpful")]);

        let response = dispatch_transform(&TestTransform, &SessionId::new("sess-1"), sections).await;
        assert_eq!(
            response.sections["instructions"].content,
            "[modified] be helpful"
        );
    }

    #[tokio::test]
    async fn dispatch_passes_through_unhandled_section() {
        let sections = sections_from(&[("context", "original context")]);

        let response = dispatch_transform(&TestTransform, &SessionId::new("sess-1"), sections).await;
        assert_eq!(response.sections["context"].content, "original context");
    }

    #[tokio::test]
    async fn dispatch_unknown_section_passes_through() {
        let sections = sections_from(&[("unknown", "mystery")]);

        let response = dispatch_transform(&TestTransform, &SessionId::new("sess-1"), sections).await;
        assert_eq!(response.sections["unknown"].content, "mystery");
    }

    #[tokio::test]
    async fn dispatch_mixed_sections() {
        let sections = sections_from(&[
            ("instructions", "help me"),
            ("context", "some context"),
            ("other", "other stuff"),
        ]);

        let response = dispatch_transform(&TestTransform, &SessionId::new("sess-1"), sections).await;
        // No section may be dropped or invented.
        assert_eq!(response.sections.len(), 3);
        assert_eq!(
            response.sections["instructions"].content,
            "[modified] help me"
        );
        assert_eq!(response.sections["context"].content, "some context");
        assert_eq!(response.sections["other"].content, "other stuff");
    }

    #[tokio::test]
    async fn dispatch_empty_input_yields_empty_response() {
        let response =
            dispatch_transform(&TestTransform, &SessionId::new("sess-1"), HashMap::new()).await;
        assert!(response.sections.is_empty());
    }

    #[tokio::test]
    async fn dispatch_offers_each_section_once_with_session_context() {
        let transform = RecordingTransform {
            calls: Mutex::new(Vec::new()),
        };
        let sections = sections_from(&[("a", "one"), ("b", "two")]);

        let response = dispatch_transform(&transform, &SessionId::new("sess-42"), sections).await;

        // `None` from the transform must leave every section untouched.
        assert_eq!(response.sections["a"].content, "one");
        assert_eq!(response.sections["b"].content, "two");

        // HashMap iteration order is unspecified, so compare sorted call logs.
        let mut calls = transform.calls.into_inner().unwrap();
        calls.sort();
        let expected_session = format!("{:?}", SessionId::new("sess-42"));
        assert_eq!(
            calls,
            vec![
                ("a".to_string(), expected_session.clone()),
                ("b".to_string(), expected_session),
            ]
        );
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn section_ids_returns_registered_sections() {
        let ids = TestTransform.section_ids();
        assert_eq!(ids, vec!["instructions", "context"]);
    }
}