1use async_trait::async_trait;
8
9use crate::llm::{LlmRequest, LlmResponse};
10
11#[derive(Debug, thiserror::Error)]
13pub enum ProcessorError {
14 #[error("Processor error: {0}")]
16 Processing(String),
17}
18
19#[async_trait]
21pub trait RequestProcessor: Send + Sync {
22 fn name(&self) -> &str;
24
25 async fn process_request(&self, request: LlmRequest) -> Result<LlmRequest, ProcessorError>;
27}
28
29#[async_trait]
31pub trait ResponseProcessor: Send + Sync {
32 fn name(&self) -> &str;
34
35 async fn process_response(&self, response: LlmResponse) -> Result<LlmResponse, ProcessorError>;
37}
38
/// Request processor that adds a fixed piece of text to a request's
/// system instruction.
pub struct InstructionInserter {
    /// Text that gets inserted into (or appended to) the system instruction.
    instruction: String,
}

impl InstructionInserter {
    /// Creates an inserter for `instruction`.
    ///
    /// Accepts anything convertible into a `String` (for example `&str`).
    pub fn new(instruction: impl Into<String>) -> Self {
        let instruction: String = instruction.into();
        Self { instruction }
    }
}
52
53#[async_trait]
54impl RequestProcessor for InstructionInserter {
55 fn name(&self) -> &str {
56 "instruction_inserter"
57 }
58
59 async fn process_request(&self, mut request: LlmRequest) -> Result<LlmRequest, ProcessorError> {
60 match &mut request.system_instruction {
61 Some(existing) => {
62 existing.push('\n');
63 existing.push_str(&self.instruction);
64 }
65 None => {
66 request.system_instruction = Some(self.instruction.clone());
67 }
68 }
69 Ok(request)
70 }
71}
72
/// Request processor that filters the parts carried by a request's contents.
pub struct ContentFilter {
    /// When `true`, only text parts are kept by the filter.
    text_only: bool,
}

impl ContentFilter {
    /// Creates a filter that keeps only text parts.
    pub fn text_only() -> Self {
        let text_only = true;
        Self { text_only }
    }
}
85
86#[async_trait]
87impl RequestProcessor for ContentFilter {
88 fn name(&self) -> &str {
89 "content_filter"
90 }
91
92 async fn process_request(&self, mut request: LlmRequest) -> Result<LlmRequest, ProcessorError> {
93 if self.text_only {
94 for content in &mut request.contents {
95 content
96 .parts
97 .retain(|p| matches!(p, rs_genai::prelude::Part::Text { .. }));
98 }
99 }
100 Ok(request)
101 }
102}
103
104#[derive(Default)]
106pub struct RequestProcessorChain {
107 processors: Vec<Box<dyn RequestProcessor>>,
108}
109
110impl RequestProcessorChain {
111 pub fn new() -> Self {
113 Self::default()
114 }
115
116 pub fn add(&mut self, processor: impl RequestProcessor + 'static) {
118 self.processors.push(Box::new(processor));
119 }
120
121 pub async fn process(&self, mut request: LlmRequest) -> Result<LlmRequest, ProcessorError> {
123 for processor in &self.processors {
124 request = processor.process_request(request).await?;
125 }
126 Ok(request)
127 }
128
129 pub fn len(&self) -> usize {
131 self.processors.len()
132 }
133
134 pub fn is_empty(&self) -> bool {
136 self.processors.is_empty()
137 }
138}
139
140#[derive(Default)]
142pub struct ResponseProcessorChain {
143 processors: Vec<Box<dyn ResponseProcessor>>,
144}
145
146impl ResponseProcessorChain {
147 pub fn new() -> Self {
149 Self::default()
150 }
151
152 pub fn add(&mut self, processor: impl ResponseProcessor + 'static) {
154 self.processors.push(Box::new(processor));
155 }
156
157 pub async fn process(&self, mut response: LlmResponse) -> Result<LlmResponse, ProcessorError> {
159 for processor in &self.processors {
160 response = processor.process_response(response).await?;
161 }
162 Ok(response)
163 }
164
165 pub fn len(&self) -> usize {
167 self.processors.len()
168 }
169
170 pub fn is_empty(&self) -> bool {
172 self.processors.is_empty()
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LlmRequest;

    /// Minimal response processor that hands its input back unchanged.
    /// Used only to populate a `ResponseProcessorChain` in tests.
    struct PassThrough;

    #[async_trait::async_trait]
    impl ResponseProcessor for PassThrough {
        fn name(&self) -> &str {
            "pass_through"
        }

        async fn process_response(
            &self,
            response: crate::llm::LlmResponse,
        ) -> Result<crate::llm::LlmResponse, ProcessorError> {
            Ok(response)
        }
    }

    // Compile-time check: the nested fn only type-checks if the trait can be
    // used as `dyn RequestProcessor`.
    #[test]
    fn request_processor_is_object_safe() {
        fn _assert(_: &dyn RequestProcessor) {}
    }

    // Compile-time check: same as above for `dyn ResponseProcessor`.
    #[test]
    fn response_processor_is_object_safe() {
        fn _assert(_: &dyn ResponseProcessor) {}
    }

    #[tokio::test]
    async fn instruction_inserter_sets_instruction() {
        let inserter = InstructionInserter::new("Be helpful");
        let req = LlmRequest::from_text("Hello");
        let processed = inserter.process_request(req).await.unwrap();
        assert_eq!(processed.system_instruction, Some("Be helpful".into()));
    }

    #[tokio::test]
    async fn instruction_inserter_appends_to_existing() {
        let inserter = InstructionInserter::new("And concise");
        let mut req = LlmRequest::from_text("Hello");
        req.system_instruction = Some("Be helpful".into());
        let processed = inserter.process_request(req).await.unwrap();
        assert_eq!(
            processed.system_instruction,
            Some("Be helpful\nAnd concise".into())
        );
    }

    #[tokio::test]
    async fn content_filter_text_only() {
        use rs_genai::prelude::{Content, Part, Role};

        let filter = ContentFilter::text_only();
        let req = LlmRequest {
            contents: vec![Content {
                role: Some(Role::User),
                parts: vec![
                    Part::Text {
                        text: "hello".into(),
                    },
                    Part::InlineData {
                        inline_data: rs_genai::prelude::Blob {
                            mime_type: "image/png".into(),
                            data: "base64data".into(),
                        },
                    },
                ],
            }],
            ..Default::default()
        };
        let processed = filter.process_request(req).await.unwrap();
        // The inline-data part is dropped; only the text part survives.
        assert_eq!(processed.contents[0].parts.len(), 1);
        assert!(matches!(&processed.contents[0].parts[0], Part::Text { .. }));
    }

    #[tokio::test]
    async fn request_processor_chain() {
        let mut chain = RequestProcessorChain::new();
        chain.add(InstructionInserter::new("Rule 1"));
        chain.add(InstructionInserter::new("Rule 2"));

        let req = LlmRequest::from_text("Hello");
        let processed = chain.process(req).await.unwrap();
        // Processors run in insertion order, so "Rule 1" precedes "Rule 2".
        assert_eq!(processed.system_instruction, Some("Rule 1\nRule 2".into()));
    }

    #[tokio::test]
    async fn empty_request_chain_leaves_instruction_unset() {
        let chain = RequestProcessorChain::new();
        let req = LlmRequest::from_text("Hello");
        let processed = chain.process(req).await.unwrap();
        assert!(processed.system_instruction.is_none());
    }

    #[test]
    fn chain_len() {
        let mut chain = RequestProcessorChain::new();
        assert!(chain.is_empty());
        chain.add(InstructionInserter::new("x"));
        assert_eq!(chain.len(), 1);
        assert!(!chain.is_empty());
    }

    #[test]
    fn response_chain_len() {
        let mut chain = ResponseProcessorChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
        chain.add(PassThrough);
        assert_eq!(chain.len(), 1);
        assert!(!chain.is_empty());
    }
}