// smcp_computer/inputs/handler.rs

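//! 文件名 / File: handler
//! 作者 / Author: JQQ
//! 创建日期 / Created: 2025/12/15
//! 最后修改日期 / Last modified: 2025/12/15
//! 版权 / Copyright: 2023 JQQ. All rights reserved.
//! 依赖 / Dependencies: tokio, async-trait
//! 描述 / Description: 输入处理器,负责协调各种输入提供者 —
//! input handler that coordinates the various input providers.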
10use super::model::*;
11use super::providers::{
12    CliInputProvider, CompositeInputProvider, EnvironmentInputProvider, InputProvider,
13};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17use tracing::{debug, error, info};
18
19/// 输入处理器 / Input handler
20pub struct InputHandler {
21    /// 输入提供者 / Input provider
22    provider: Arc<dyn InputProvider>,
23    /// 缓存的输入值 / Cached input values
24    cache: Arc<RwLock<HashMap<String, InputValue>>>,
25    /// 是否启用缓存 / Whether to enable cache
26    enable_cache: bool,
27}
28
29impl InputHandler {
30    /// 创建新的输入处理器 / Create new input handler
31    pub fn new() -> Self {
32        // 默认使用组合提供者:先尝试环境变量,再尝试CLI
33        // Use composite provider by default: try environment variable first, then CLI
34        let provider: Box<dyn InputProvider> = Box::new(
35            CompositeInputProvider::new()
36                .add_provider(Box::new(EnvironmentInputProvider::new()))
37                .add_provider(Box::new(CliInputProvider::new())),
38        );
39
40        Self {
41            provider: Arc::from(provider),
42            cache: Arc::new(RwLock::new(HashMap::new())),
43            enable_cache: true,
44        }
45    }
46
47    /// 使用自定义提供者创建输入处理器 / Create input handler with custom provider
48    pub fn with_provider<P>(provider: P) -> Self
49    where
50        P: InputProvider + 'static,
51    {
52        Self {
53            provider: Arc::new(provider),
54            cache: Arc::new(RwLock::new(HashMap::new())),
55            enable_cache: true,
56        }
57    }
58
59    /// 设置是否启用缓存 / Set whether to enable cache
60    pub fn with_cache(mut self, enable: bool) -> Self {
61        self.enable_cache = enable;
62        self
63    }
64
65    /// 获取单个输入 / Get single input
66    pub async fn get_input(
67        &self,
68        request: InputRequest,
69        context: InputContext,
70    ) -> InputResult<InputResponse> {
71        debug!("Getting input for: {} (context: {:?})", request.id, context);
72
73        // 检查缓存 / Check cache
74        if self.enable_cache {
75            let cache_key = self.build_cache_key(&request.id, &context);
76            if let Some(value) = self.get_cached_value(&cache_key).await {
77                debug!("Using cached value for: {}", request.id);
78                return Ok(InputResponse {
79                    id: request.id,
80                    value,
81                    cancelled: false,
82                });
83            }
84        }
85
86        // 从提供者获取输入 / Get input from provider
87        let mut response = self.provider.get_input(&request, &context).await;
88
89        // 如果获取失败且有默认值,返回默认值
90        // If failed and has default value, return default value
91        if response.is_err() && !request.required {
92            if let Some(default) = &request.default {
93                info!("Using default value for: {}", request.id);
94                response = Ok(InputResponse {
95                    id: request.id.clone(),
96                    value: default.clone(),
97                    cancelled: false,
98                });
99            }
100        }
101
102        // 缓存结果 / Cache result
103        if self.enable_cache {
104            if let Ok(ref resp) = response {
105                if !resp.cancelled {
106                    let cache_key = self.build_cache_key(&request.id, &context);
107                    self.cache_value(cache_key, resp.value.clone()).await;
108                }
109            }
110        }
111
112        response
113    }
114
115    /// 批量获取输入 / Get multiple inputs
116    pub async fn get_inputs(
117        &self,
118        requests: Vec<InputRequest>,
119        context: InputContext,
120    ) -> InputResult<Vec<InputResponse>> {
121        let mut responses = Vec::new();
122
123        for request in requests {
124            match self.get_input(request, context.clone()).await {
125                Ok(response) => responses.push(response),
126                Err(e) => {
127                    error!("Failed to get input: {}", e);
128                    return Err(e);
129                }
130            }
131        }
132
133        Ok(responses)
134    }
135
136    /// 清除缓存 / Clear cache
137    pub async fn clear_cache(&self) {
138        self.cache.write().await.clear();
139        debug!("Input cache cleared");
140    }
141
142    /// 清除特定缓存 / Clear specific cache
143    pub async fn clear_cache_for(&self, id: &str, context: &InputContext) {
144        let cache_key = self.build_cache_key(id, context);
145        let mut cache = self.cache.write().await;
146        cache.remove(&cache_key);
147        debug!("Cleared cache for: {}", id);
148    }
149
150    /// 构建缓存键 / Build cache key
151    fn build_cache_key(&self, id: &str, context: &InputContext) -> String {
152        let mut key = id.to_string();
153
154        if let Some(server) = &context.server_name {
155            key = format!("{}:{}", key, server);
156        }
157
158        if let Some(tool) = &context.tool_name {
159            key = format!("{}:{}", key, tool);
160        }
161
162        // 添加其他元数据 / Add other metadata
163        if !context.metadata.is_empty() {
164            let mut metadata_pairs: Vec<_> = context.metadata.iter().collect();
165            metadata_pairs.sort_by(|(k1, _), (k2, _)| k1.cmp(k2)); // 确保顺序一致 / Ensure consistent order
166
167            for (k, v) in metadata_pairs {
168                key = format!("{}:{}={}", key, k, v);
169            }
170        }
171
172        key
173    }
174
175    /// 获取缓存值 / Get cached value
176    async fn get_cached_value(&self, key: &str) -> Option<InputValue> {
177        let cache = self.cache.read().await;
178        cache.get(key).cloned()
179    }
180
181    /// 缓存值 / Cache value
182    async fn cache_value(&self, key: String, value: InputValue) {
183        let mut cache = self.cache.write().await;
184        cache.insert(key, value);
185    }
186
187    /// 获取所有缓存值 / Get all cached values
188    pub async fn get_all_cached_values(&self) -> HashMap<String, InputValue> {
189        let cache = self.cache.read().await;
190        cache.clone()
191    }
192
193    /// 设置缓存值 / Set cached value
194    pub async fn set_cached_value(&self, key: String, value: InputValue) {
195        self.cache_value(key, value).await;
196    }
197
198    /// 删除缓存值 / Remove cached value
199    pub async fn remove_cached_value(&self, key: &str) -> Option<InputValue> {
200        let mut cache = self.cache.write().await;
201        cache.remove(key)
202    }
203
204    /// 清空所有缓存 / Clear all cache
205    pub async fn clear_all_cache(&self) {
206        self.cache.write().await.clear();
207    }
208
209    /// 从MCP服务器输入配置创建请求 / Create request from MCP server input configuration
210    pub fn create_request_from_mcp_input(
211        &self,
212        mcp_input: &crate::mcp_clients::model::MCPServerInput,
213        default: Option<InputValue>,
214    ) -> InputRequest {
215        match mcp_input {
216            crate::mcp_clients::model::MCPServerInput::PromptString(input) => InputRequest {
217                id: input.id.clone(),
218                input_type: InputType::String {
219                    password: input.password,
220                    min_length: None,
221                    max_length: None,
222                },
223                title: input.description.clone(),
224                description: input.description.clone(),
225                default,
226                required: true,
227                validation: None,
228            },
229            crate::mcp_clients::model::MCPServerInput::PickString(input) => InputRequest {
230                id: input.id.clone(),
231                input_type: InputType::PickString {
232                    options: input.options.clone(),
233                    multiple: false,
234                },
235                title: input.description.clone(),
236                description: input.description.clone(),
237                default,
238                required: true,
239                validation: None,
240            },
241            crate::mcp_clients::model::MCPServerInput::Command(input) => InputRequest {
242                id: input.id.clone(),
243                input_type: InputType::Command {
244                    command: input.command.clone(),
245                    args: input
246                        .args
247                        .as_ref()
248                        .map(|m| {
249                            let mut sorted_pairs: Vec<_> = m.iter().collect();
250                            sorted_pairs.sort_by_key(|(k, _)| *k);
251                            sorted_pairs.into_iter().map(|(_, v)| v.clone()).collect()
252                        })
253                        .unwrap_or_default(),
254                },
255                title: input.description.clone(),
256                description: input.description.clone(),
257                default,
258                required: true,
259                validation: None,
260            },
261        }
262    }
263
264    /// 处理MCP服务器输入 / Handle MCP server inputs
265    pub async fn handle_mcp_inputs(
266        &self,
267        inputs: &[crate::mcp_clients::model::MCPServerInput],
268        context: InputContext,
269    ) -> InputResult<HashMap<String, InputValue>> {
270        let mut results = HashMap::new();
271        let mut requests = Vec::new();
272
273        // 创建请求 / Create requests
274        for input in inputs {
275            let request = self.create_request_from_mcp_input(input, None);
276            requests.push(request);
277        }
278
279        // 获取输入 / Get inputs
280        let responses = self.get_inputs(requests, context).await?;
281
282        // 收集结果 / Collect results
283        for response in responses {
284            results.insert(response.id, response.value);
285        }
286
287        Ok(results)
288    }
289}
290
291impl Default for InputHandler {
292    fn default() -> Self {
293        Self::new()
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp_clients::model::*;

    /// A freshly constructed handler has caching switched on.
    #[tokio::test]
    async fn test_input_handler_creation() {
        assert!(InputHandler::new().enable_cache);
    }

    /// Server and tool names are appended to the id, in that order, `:`-separated.
    #[tokio::test]
    async fn test_cache_key_generation() {
        let ctx = InputContext::new()
            .with_server_name("test_server".to_string())
            .with_tool_name("test_tool".to_string());

        assert_eq!(
            InputHandler::new().build_cache_key("test_input", &ctx),
            "test_input:test_server:test_tool"
        );
    }

    /// A stored value can be read back under the same key.
    #[tokio::test]
    async fn test_cache_operations() {
        let handler = InputHandler::new();
        let expected = InputValue::String("test_value".to_string());

        handler
            .cache_value("test_key".to_string(), expected.clone())
            .await;

        assert_eq!(handler.get_cached_value("test_key").await, Some(expected));
    }

    /// A prompt-string declaration maps onto a required string request.
    #[tokio::test]
    async fn test_create_request_from_mcp_input() {
        let handler = InputHandler::new();
        let prompt = MCPServerInput::PromptString(PromptStringInput {
            id: "test_input".to_string(),
            description: "Test input".to_string(),
            default: Some("default".to_string()),
            password: Some(false),
        });

        let request = handler.create_request_from_mcp_input(&prompt, None);

        assert_eq!(request.id, "test_input");
        assert_eq!(request.title, "Test input");
        assert_eq!(request.description, "Test input");
        assert!(request.required);

        if let InputType::String { password, .. } = request.input_type {
            assert_eq!(password, Some(false));
        } else {
            panic!("Expected string input type");
        }
    }
}
359}