1use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::{Mutex, RwLock};
17use tracing::{debug, error, info};
18
19use crate::errors::{ComputerError, ComputerResult};
20use crate::inputs::handler::InputHandler;
21use crate::inputs::model::InputValue;
22use crate::inputs::utils::run_command;
23use crate::mcp_clients::{
24 manager::MCPServerManager,
25 model::{
26 content_as_text, is_call_tool_error, CallToolResult, Content, MCPServerConfig,
27 MCPServerInput, ReadResourceResult, Resource, Tool,
28 },
29 ConfigRender, RenderError,
30};
31use crate::socketio_client::SmcpComputerClient;
32
33type ConfirmCallbackType = Arc<dyn Fn(&str, &str, &str, &serde_json::Value) -> bool + Send + Sync>;
35
36fn input_value_to_json(value: InputValue) -> serde_json::Value {
38 match value {
39 InputValue::String(s) => serde_json::Value::String(s),
40 InputValue::Number(n) => serde_json::Value::Number(serde_json::Number::from(n)),
41 InputValue::Float(f) => serde_json::Value::Number(
42 serde_json::Number::from_f64(f).unwrap_or(serde_json::Number::from(0)),
43 ),
44 InputValue::Bool(b) => serde_json::Value::Bool(b),
45 }
46}
47
48fn json_to_input_value(value: serde_json::Value) -> ComputerResult<InputValue> {
50 match value {
51 serde_json::Value::String(s) => Ok(InputValue::String(s)),
52 serde_json::Value::Number(n) => {
53 if let Some(i) = n.as_i64() {
54 Ok(InputValue::Number(i))
55 } else if let Some(u) = n.as_u64() {
56 Ok(InputValue::Number(u as i64))
57 } else if let Some(f) = n.as_f64() {
58 Ok(InputValue::Float(f))
59 } else {
60 Err(ComputerError::ValidationError(
61 "Invalid number value".to_string(),
62 ))
63 }
64 }
65 serde_json::Value::Bool(b) => Ok(InputValue::Bool(b)),
66 serde_json::Value::Null => Err(ComputerError::ValidationError(
67 "Null value not supported".to_string(),
68 )),
69 _ => Err(ComputerError::ValidationError(
70 "Unsupported value type".to_string(),
71 )),
72 }
73}
74
/// One entry of a `Computer`'s tool-call history (see `execute_tool`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// When the call was started (captured before confirmation and execution).
    pub timestamp: DateTime<Utc>,
    /// Request id supplied by the caller.
    pub req_id: String,
    /// Name of the MCP server the tool was resolved to.
    pub server: String,
    /// Tool name as returned by the manager's tool validation.
    pub tool: String,
    /// Parameters the call was made with.
    pub parameters: serde_json::Value,
    /// Requested timeout in seconds, if any.
    pub timeout: Option<f64>,
    /// `true` only if the tool was actually invoked and its result is not an error.
    pub success: bool,
    /// Error text extracted for the call, if any.
    pub error: Option<String>,
}
95
96#[async_trait]
99pub trait Session: Send + Sync {
100 async fn resolve_input(&self, input: &MCPServerInput) -> ComputerResult<serde_json::Value>;
102
103 fn session_id(&self) -> &str;
105}
106
107pub struct SilentSession {
109 id: String,
110}
111
112impl SilentSession {
113 pub fn new(id: impl Into<String>) -> Self {
115 Self { id: id.into() }
116 }
117}
118
119#[async_trait]
120impl Session for SilentSession {
121 async fn resolve_input(&self, input: &MCPServerInput) -> ComputerResult<serde_json::Value> {
122 match input {
124 MCPServerInput::PromptString(input) => Ok(serde_json::Value::String(
125 input.default.clone().unwrap_or_default(),
126 )),
127 MCPServerInput::PickString(input) => Ok(serde_json::Value::String(
128 input
129 .default
130 .clone()
131 .unwrap_or_else(|| input.options.first().cloned().unwrap_or_default()),
132 )),
133 MCPServerInput::Command(input) => {
134 let args: Vec<String> = input
136 .args
137 .as_ref()
138 .map(|m| {
139 let mut sorted_pairs: Vec<_> = m.iter().collect();
140 sorted_pairs.sort_by_key(|(k, _)| *k);
141 sorted_pairs.into_iter().map(|(_, v)| v.clone()).collect()
142 })
143 .unwrap_or_default();
144 match run_command(&input.command, &args).await {
145 Ok(output) => Ok(serde_json::Value::String(output)),
146 Err(e) => Err(ComputerError::RuntimeError(format!(
147 "Failed to execute command '{}': {}",
148 input.command, e
149 ))),
150 }
151 }
152 }
153 }
154
155 fn session_id(&self) -> &str {
156 &self.id
157 }
158}
159
/// A local "computer" that owns MCP server configurations, their input
/// definitions, a cache of input values, and an optional Socket.IO client.
///
/// `S` supplies input values when server configurations are rendered.
pub struct Computer<S: Session> {
    /// Name used in log messages and passed to the Socket.IO client on connect.
    name: String,
    /// The MCP server manager; `None` until `boot_up` or `add_or_update_server`
    /// creates one, and taken out again by `shutdown`.
    mcp_manager: Arc<RwLock<Option<MCPServerManager>>>,
    /// Input definitions keyed by input id. This `Arc` is also handed to the
    /// Socket.IO client in `connect_socketio`.
    inputs: Arc<RwLock<HashMap<String, MCPServerInput>>>,
    /// Server configurations keyed by server name, stored as given (i.e. before
    /// `render_server_config` is applied).
    mcp_servers: RwLock<HashMap<String, MCPServerConfig>>,
    /// Cache of input values (read/written by the `*_input_value*` methods).
    input_handler: Arc<RwLock<InputHandler>>,
    /// Stored from `new`; NOTE(review): not read by any `Computer` method in
    /// this file — confirm whether it is used elsewhere.
    auto_connect: bool,
    /// Stored from `new`; NOTE(review): not read by any `Computer` method in
    /// this file — confirm whether it is used elsewhere.
    auto_reconnect: bool,
    /// Most recent tool calls, oldest first; `execute_tool` keeps at most 10.
    tool_history: Arc<Mutex<Vec<ToolCallRecord>>>,
    /// Resolves input values while rendering server configurations.
    session: S,
    /// The Socket.IO client, if connected.
    socketio_client: Arc<RwLock<Option<Arc<SmcpComputerClient>>>>,
    /// Optional hook consulted by `execute_tool` before running a tool.
    confirm_callback: Option<ConfirmCallbackType>,
}
189
190impl<S: Session> Computer<S> {
191 pub fn new(
193 name: impl Into<String>,
194 session: S,
195 inputs: Option<HashMap<String, MCPServerInput>>,
196 mcp_servers: Option<HashMap<String, MCPServerConfig>>,
197 auto_connect: bool,
198 auto_reconnect: bool,
199 ) -> Self {
200 let name = name.into();
201 let inputs = inputs.unwrap_or_default();
202 let mcp_servers = mcp_servers.unwrap_or_default();
203
204 Self {
205 name,
206 mcp_manager: Arc::new(RwLock::new(None)),
207 inputs: Arc::new(RwLock::new(inputs)),
208 mcp_servers: RwLock::new(mcp_servers),
209 input_handler: Arc::new(RwLock::new(InputHandler::new())),
210 auto_connect,
211 auto_reconnect,
212 tool_history: Arc::new(Mutex::new(Vec::new())),
213 session,
214 socketio_client: Arc::new(RwLock::new(None)),
215 confirm_callback: None,
216 }
217 }
218
219 pub fn with_confirm_callback<F>(mut self, callback: F) -> Self
221 where
222 F: Fn(&str, &str, &str, &serde_json::Value) -> bool + Send + Sync + 'static,
223 {
224 self.confirm_callback = Some(Arc::new(callback));
225 self
226 }
227
228 pub fn name(&self) -> &str {
230 &self.name
231 }
232
233 pub fn get_socketio_client(&self) -> Arc<RwLock<Option<Arc<SmcpComputerClient>>>> {
237 self.socketio_client.clone()
238 }
239
240 pub async fn boot_up(&self) -> ComputerResult<()> {
242 info!("Starting Computer: {}", self.name);
243
244 let manager = MCPServerManager::new();
246
247 let servers = self.mcp_servers.read().await;
249 let mut validated_servers = Vec::new();
250
251 for (_name, server_config) in servers.iter() {
252 match self.render_server_config(server_config).await {
253 Ok(validated) => validated_servers.push(validated),
254 Err(e) => {
255 error!(
256 "Failed to render server config {}: {}",
257 server_config.name(),
258 e
259 );
260 validated_servers.push(server_config.clone());
262 }
263 }
264 }
265
266 manager.initialize(validated_servers).await?;
268
269 *self.mcp_manager.write().await = Some(manager);
271
272 info!("Computer {} started successfully", self.name);
273 Ok(())
274 }
275
276 async fn render_server_config(
280 &self,
281 config: &MCPServerConfig,
282 ) -> ComputerResult<MCPServerConfig> {
283 let config_json = serde_json::to_value(config)?;
285
286 let renderer = ConfigRender::default();
288
289 let inputs = self.inputs.read().await;
291 let inputs_clone: std::collections::HashMap<String, MCPServerInput> = inputs.clone();
292 drop(inputs); let mut resolved_values: std::collections::HashMap<String, serde_json::Value> =
298 std::collections::HashMap::new();
299 for (input_id, input) in inputs_clone.iter() {
300 match self.session.resolve_input(input).await {
301 Ok(value) => {
302 resolved_values.insert(input_id.clone(), value);
303 }
304 Err(e) => {
305 debug!(
306 "Failed to resolve input '{}': {}, will use default",
307 input_id, e
308 );
309 if let Some(default) = input.default() {
311 resolved_values.insert(input_id.clone(), default);
312 }
313 }
314 }
315 }
316
317 let resolver = |input_id: String| {
319 let values = resolved_values.clone();
320 async move {
321 if let Some(value) = values.get(&input_id) {
322 Ok(value.clone())
323 } else {
324 Err(RenderError::InputNotFound(input_id))
325 }
326 }
327 };
328
329 let rendered_json = renderer.render(config_json, resolver).await?;
331
332 let rendered_config: MCPServerConfig = serde_json::from_value(rendered_json)?;
334
335 Ok(rendered_config)
336 }
337
338 pub async fn add_or_update_server(&self, server: MCPServerConfig) -> ComputerResult<()> {
340 {
342 let mut manager_guard = self.mcp_manager.write().await;
343 if manager_guard.is_none() {
344 *manager_guard = Some(MCPServerManager::new());
345 }
346 }
347
348 let validated = self.render_server_config(&server).await?;
350
351 let manager = self.mcp_manager.read().await;
353 if let Some(ref manager) = *manager {
354 manager.add_or_update_server(validated).await?;
355 }
356
357 {
359 let mut servers = self.mcp_servers.write().await;
360 servers.insert(server.name().to_string(), server);
361 }
362
363 let _ = self.emit_update_config().await;
365
366 Ok(())
367 }
368
369 pub async fn remove_server(&self, server_name: &str) -> ComputerResult<()> {
371 let manager = self.mcp_manager.read().await;
372 if let Some(ref manager) = *manager {
373 manager.remove_server(server_name).await?;
374 }
375
376 {
378 let mut servers = self.mcp_servers.write().await;
379 servers.remove(server_name);
380 }
381
382 let _ = self.emit_update_config().await;
384
385 Ok(())
386 }
387
388 pub async fn update_inputs(
390 &self,
391 inputs: HashMap<String, MCPServerInput>,
392 ) -> ComputerResult<()> {
393 *self.inputs.write().await = inputs;
394
395 {
397 let mut input_handler = self.input_handler.write().await;
398 *input_handler = InputHandler::new();
399 }
400
401 let _ = self.emit_update_config().await;
403
404 Ok(())
405 }
406
407 pub async fn add_or_update_input(&self, input: MCPServerInput) -> ComputerResult<()> {
409 let input_id = input.id().to_string();
410 {
411 let mut inputs = self.inputs.write().await;
412 inputs.insert(input_id.clone(), input);
413 }
414
415 self.clear_input_values(Some(&input_id)).await?;
417
418 let _ = self.emit_update_config().await;
420
421 Ok(())
422 }
423
424 pub async fn remove_input(&self, input_id: &str) -> ComputerResult<bool> {
426 let removed = {
427 let mut inputs = self.inputs.write().await;
428 inputs.remove(input_id).is_some()
429 };
430
431 if removed {
432 self.clear_input_values(Some(input_id)).await?;
434
435 let _ = self.emit_update_config().await;
437 }
438
439 Ok(removed)
440 }
441
442 pub async fn get_input(&self, input_id: &str) -> ComputerResult<Option<MCPServerInput>> {
444 let inputs = self.inputs.read().await;
445 Ok(inputs.get(input_id).cloned())
446 }
447
448 pub async fn list_inputs(&self) -> ComputerResult<Vec<MCPServerInput>> {
450 let inputs = self.inputs.read().await;
451 Ok(inputs.values().cloned().collect())
452 }
453
454 pub async fn get_input_value(
456 &self,
457 input_id: &str,
458 ) -> ComputerResult<Option<serde_json::Value>> {
459 let handler = self.input_handler.read().await;
461 let cached_values = handler.get_all_cached_values().await;
462
463 for (key, value) in cached_values {
465 if key.starts_with(input_id) {
468 let parts: Vec<&str> = key.split(':').collect();
470 if !parts.is_empty() && parts[0] == input_id {
471 return Ok(Some(input_value_to_json(value)));
472 }
473 }
474 }
475
476 Ok(None)
477 }
478
479 pub async fn set_input_value(
481 &self,
482 input_id: &str,
483 value: serde_json::Value,
484 ) -> ComputerResult<bool> {
485 {
487 let inputs = self.inputs.read().await;
488 if !inputs.contains_key(input_id) {
489 return Ok(false);
490 }
491 }
492
493 let handler = self.input_handler.read().await;
495 let input_value = json_to_input_value(value)?;
496 handler
497 .set_cached_value(input_id.to_string(), input_value)
498 .await;
499
500 Ok(true)
501 }
502
503 pub async fn remove_input_value(&self, input_id: &str) -> ComputerResult<bool> {
505 let handler = self.input_handler.read().await;
506 let removed = handler.remove_cached_value(input_id).await.is_some();
507 Ok(removed)
508 }
509
510 pub async fn list_input_values(&self) -> ComputerResult<HashMap<String, serde_json::Value>> {
512 let handler = self.input_handler.read().await;
513 let cached_values = handler.get_all_cached_values().await;
514
515 let mut result = HashMap::new();
516 for (key, value) in cached_values {
517 let parts: Vec<&str> = key.split(':').collect();
520 if !parts.is_empty() {
521 result.insert(parts[0].to_string(), input_value_to_json(value));
522 }
523 }
524
525 Ok(result)
526 }
527
528 pub async fn clear_input_values(&self, input_id: Option<&str>) -> ComputerResult<()> {
530 let handler = self.input_handler.read().await;
531
532 if let Some(id) = input_id {
533 let cached_values = handler.get_all_cached_values().await;
535 let keys_to_remove: Vec<String> = cached_values
536 .keys()
537 .filter(|key| key.starts_with(id))
538 .cloned()
539 .collect();
540
541 for key in keys_to_remove {
542 handler.remove_cached_value(&key).await;
543 }
544 } else {
545 handler.clear_all_cache().await;
547 }
548
549 Ok(())
550 }
551
552 pub async fn get_available_tools(&self) -> ComputerResult<Vec<Tool>> {
554 let manager = self.mcp_manager.read().await;
555 if let Some(ref manager) = *manager {
556 let tools: Vec<Tool> = manager.list_available_tools().await;
557 Ok(tools)
561 } else {
562 Err(ComputerError::InvalidState(
563 "Computer not initialized".to_string(),
564 ))
565 }
566 }
567
568 pub async fn list_all_windows(
570 &self,
571 window_uri: Option<&str>,
572 ) -> ComputerResult<Vec<(String, Resource)>> {
573 let manager = self.mcp_manager.read().await;
574 if let Some(ref manager) = *manager {
575 Ok(manager.list_all_windows(window_uri).await)
576 } else {
577 Err(ComputerError::InvalidState(
578 "Computer not initialized".to_string(),
579 ))
580 }
581 }
582
583 pub async fn get_windows_details(
585 &self,
586 window_uri: Option<&str>,
587 ) -> ComputerResult<Vec<(String, Resource, ReadResourceResult)>> {
588 let manager = self.mcp_manager.read().await;
589 if let Some(ref manager) = *manager {
590 Ok(manager.get_windows_details(window_uri).await)
591 } else {
592 Err(ComputerError::InvalidState(
593 "Computer not initialized".to_string(),
594 ))
595 }
596 }
597
598 pub async fn get_window_detail(
600 &self,
601 server_name: &str,
602 resource: Resource,
603 ) -> ComputerResult<ReadResourceResult> {
604 let manager = self.mcp_manager.read().await;
605 if let Some(ref manager) = *manager {
606 manager.get_window_detail(server_name, resource).await
607 } else {
608 Err(ComputerError::InvalidState(
609 "Computer not initialized".to_string(),
610 ))
611 }
612 }
613
614 pub async fn execute_tool(
616 &self,
617 req_id: &str,
618 tool_name: &str,
619 parameters: serde_json::Value,
620 timeout: Option<f64>,
621 ) -> ComputerResult<CallToolResult> {
622 let manager = self.mcp_manager.read().await;
623 if let Some(ref manager) = *manager {
624 let (server_name, tool_name) =
626 manager.validate_tool_call(tool_name, ¶meters).await?;
627 let server_name = server_name.to_string();
628 let tool_name = tool_name.to_string();
629
630 let timestamp = Utc::now();
631 let mut success = false;
632 let mut error_msg = None;
633 let result: CallToolResult;
634
635 let need_confirm = true; let parameters_for_call = parameters.clone();
641
642 if need_confirm {
643 if let Some(ref callback) = self.confirm_callback {
644 let confirmed = callback(req_id, &server_name, &tool_name, ¶meters);
645 if confirmed {
646 let timeout_duration = timeout.map(std::time::Duration::from_secs_f64);
647 result = manager
648 .call_tool(
649 &server_name,
650 &tool_name,
651 parameters_for_call,
652 timeout_duration,
653 )
654 .await?;
655 success = !is_call_tool_error(&result);
656 } else {
657 result = CallToolResult::success(vec![Content::text(
658 "工具调用二次确认被拒绝,请稍后再试",
659 )]);
660 }
661 } else {
662 result = CallToolResult::error(vec![Content::text(
663 "当前工具需要调用前进行二次确认,但客户端目前没有实现二次确认回调方法",
664 )]);
665 error_msg = Some("No confirmation callback".to_string());
666 }
667 } else {
668 let timeout_duration = timeout.map(std::time::Duration::from_secs_f64);
669 result = manager
670 .call_tool(
671 &server_name,
672 &tool_name,
673 parameters_for_call,
674 timeout_duration,
675 )
676 .await?;
677 success = !is_call_tool_error(&result);
678 }
679
680 if is_call_tool_error(&result) {
681 error_msg = result
682 .content
683 .iter()
684 .find_map(|c| content_as_text(c).map(|t| t.to_string()));
685 }
686
687 let record = ToolCallRecord {
689 timestamp,
690 req_id: req_id.to_string(),
691 server: server_name,
692 tool: tool_name,
693 parameters,
694 timeout,
695 success,
696 error: error_msg,
697 };
698
699 {
700 let mut history = self.tool_history.lock().await;
701 history.push(record);
702 if history.len() > 10 {
704 history.remove(0);
705 }
706 }
707
708 Ok(result)
709 } else {
710 Err(ComputerError::InvalidState(
711 "Computer not initialized".to_string(),
712 ))
713 }
714 }
715
716 pub async fn get_tool_history(&self) -> ComputerResult<Vec<ToolCallRecord>> {
718 let history = self.tool_history.lock().await;
719 Ok(history.clone())
720 }
721
722 pub async fn get_server_status(&self) -> Vec<(String, bool, String)> {
724 let manager_guard = self.mcp_manager.read().await;
725 if let Some(ref manager) = *manager_guard {
726 manager.get_server_status().await
727 } else {
728 Vec::new()
729 }
730 }
731
732 pub async fn list_mcp_servers(&self) -> Vec<MCPServerConfig> {
734 let servers = self.mcp_servers.read().await;
735 servers.values().cloned().collect()
736 }
737
738 pub async fn start_mcp_client(&self, server_name: &str) -> ComputerResult<()> {
740 let manager_guard = self.mcp_manager.read().await;
741 if let Some(ref manager) = *manager_guard {
742 if server_name == "all" {
743 manager.start_all().await
744 } else {
745 manager.start_client(server_name).await
746 }
747 } else {
748 Err(ComputerError::InvalidState(
749 "MCP Manager not initialized".to_string(),
750 ))
751 }
752 }
753
754 pub async fn stop_mcp_client(&self, server_name: &str) -> ComputerResult<()> {
756 let manager_guard = self.mcp_manager.read().await;
757 if let Some(ref manager) = *manager_guard {
758 if server_name == "all" {
759 manager.stop_all().await
760 } else {
761 manager.stop_client(server_name).await
762 }
763 } else {
764 Err(ComputerError::InvalidState(
765 "MCP Manager not initialized".to_string(),
766 ))
767 }
768 }
769
770 pub async fn is_mcp_manager_initialized(&self) -> bool {
772 let manager_guard = self.mcp_manager.read().await;
773 manager_guard.is_some()
774 }
775
776 pub async fn set_socketio_client(&self, client: Arc<SmcpComputerClient>) {
780 let mut socketio_ref = self.socketio_client.write().await;
781 *socketio_ref = Some(client);
784 }
785
786 pub async fn connect_socketio(
788 &self,
789 url: &str,
790 _namespace: &str,
791 auth: &Option<String>,
792 _headers: &Option<String>,
793 ) -> ComputerResult<()> {
794 let _manager_check = {
796 let manager_guard = self.mcp_manager.read().await;
797 match manager_guard.as_ref() {
798 Some(_m) => {
799 true
802 }
803 None => {
804 return Err(ComputerError::InvalidState(
805 "MCP Manager not initialized. Please add and start servers first."
806 .to_string(),
807 ));
808 }
809 }
810 };
811
812 let new_manager = MCPServerManager::new();
817
818 let client = SmcpComputerClient::new(
821 url,
822 Arc::new(RwLock::new(Some(new_manager))),
823 self.name.clone(),
824 auth.clone(),
825 self.inputs.clone(),
826 )
827 .await?;
828
829 let client_arc = Arc::new(client);
831 self.set_socketio_client(client_arc.clone()).await;
832
833 info!(
834 "Connected to SMCP server at {} with computer name: {}",
835 url, self.name
836 );
837
838 Ok(())
839 }
840
841 pub async fn disconnect_socketio(&self) -> ComputerResult<()> {
843 let mut socketio_ref = self.socketio_client.write().await;
844 *socketio_ref = None;
845 info!("Disconnected from server");
846 Ok(())
847 }
848
849 pub async fn join_office(&self, office_id: &str, _computer_name: &str) -> ComputerResult<()> {
851 let socketio_ref = self.socketio_client.read().await;
852 if let Some(ref client) = *socketio_ref {
853 client.join_office(office_id).await?;
856 return Ok(());
857 }
858 Err(ComputerError::InvalidState(
859 "Socket.IO client not connected".to_string(),
860 ))
861 }
862
863 pub async fn leave_office(&self) -> ComputerResult<()> {
865 let socketio_ref = self.socketio_client.read().await;
866 if let Some(ref client) = *socketio_ref {
867 let current_office_id = client.get_current_office_id().await?;
870 client.leave_office(¤t_office_id).await?;
871 return Ok(());
872 }
873 Err(ComputerError::InvalidState(
874 "Socket.IO client not connected".to_string(),
875 ))
876 }
877
878 pub async fn emit_update_config(&self) -> ComputerResult<()> {
880 let socketio_ref = self.socketio_client.read().await;
881 if let Some(ref client) = *socketio_ref {
882 client.emit_update_config().await?;
885 return Ok(());
886 }
887 Err(ComputerError::InvalidState(
888 "Socket.IO client not connected".to_string(),
889 ))
890 }
891
892 pub async fn shutdown(&self) -> ComputerResult<()> {
894 info!("Shutting down Computer: {}", self.name);
895
896 let mut manager_guard = self.mcp_manager.write().await;
897 if let Some(manager) = manager_guard.take() {
898 manager.stop_all().await?;
899 }
900
901 {
903 let mut socketio_ref = self.socketio_client.write().await;
904 *socketio_ref = None;
905 }
906
907 info!("Computer {} shutdown successfully", self.name);
908 Ok(())
909 }
910}
911
912impl<S: Session + Clone> Clone for Computer<S> {
914 fn clone(&self) -> Self {
915 Self {
916 name: self.name.clone(),
917 mcp_manager: Arc::clone(&self.mcp_manager),
918 inputs: Arc::new(RwLock::new(HashMap::new())), mcp_servers: RwLock::new(HashMap::new()),
920 input_handler: Arc::clone(&self.input_handler),
921 auto_connect: self.auto_connect,
922 auto_reconnect: self.auto_reconnect,
923 tool_history: Arc::clone(&self.tool_history),
924 session: self.session.clone(),
925 socketio_client: Arc::clone(&self.socketio_client),
926 confirm_callback: self.confirm_callback.clone(),
927 }
928 }
929}
930
/// Receiver of `ManagerChangeMessage` notifications.
///
/// NOTE(review): presumably invoked by the MCP manager layer when tools or
/// resources change; confirm at the call site.
#[async_trait]
pub trait ManagerChangeHandler: Send + Sync {
    /// Handles a single change notification.
    async fn on_change(&self, message: ManagerChangeMessage) -> ComputerResult<()>;
}
937
/// A change notification delivered to `ManagerChangeHandler::on_change`.
#[derive(Debug, Clone)]
pub enum ManagerChangeMessage {
    /// The set of available tools changed.
    ToolListChanged,
    /// The resource list changed; `windows` presumably lists the affected
    /// window URIs — TODO confirm with the producer.
    ResourceListChanged { windows: Vec<String> },
    /// The resource identified by `uri` was updated.
    ResourceUpdated { uri: String },
}
948
949#[async_trait]
950impl<S: Session> ManagerChangeHandler for Computer<S> {
951 async fn on_change(&self, message: ManagerChangeMessage) -> ComputerResult<()> {
952 match message {
953 ManagerChangeMessage::ToolListChanged => {
954 debug!("Tool list changed, notifying Socket.IO client");
955 let socketio_ref = self.socketio_client.read().await;
956 if let Some(ref client) = *socketio_ref {
957 client.emit_update_tool_list().await?;
960 }
961 }
962 ManagerChangeMessage::ResourceListChanged { windows: _ } => {
963 debug!("Resource list changed, checking for window updates");
964 }
966 ManagerChangeMessage::ResourceUpdated { uri } => {
967 debug!("Resource updated: {}", uri);
968 }
970 }
971 Ok(())
972 }
973}
974
975#[cfg(test)]
976mod tests {
977 use super::*;
978 use crate::mcp_clients::model::{
979 CommandInput, MCPServerConfig, MCPServerInput, PickStringInput, PromptStringInput,
980 StdioServerConfig, StdioServerParameters,
981 };
982
983 #[tokio::test]
984 async fn test_computer_creation() {
985 let session = SilentSession::new("test");
986 let computer = Computer::new("test_computer", session, None, None, true, true);
987
988 assert_eq!(computer.name, "test_computer");
989 assert!(computer.auto_connect);
990 assert!(computer.auto_reconnect);
991 }
992
993 #[tokio::test]
994 async fn test_computer_with_initial_inputs_and_servers() {
995 let session = SilentSession::new("test");
996 let mut inputs = HashMap::new();
997 inputs.insert(
998 "input1".to_string(),
999 MCPServerInput::PromptString(PromptStringInput {
1000 id: "input1".to_string(),
1001 description: "Test input".to_string(),
1002 default: Some("default".to_string()),
1003 password: Some(false),
1004 }),
1005 );
1006
1007 let mut servers = HashMap::new();
1008 servers.insert(
1009 "server1".to_string(),
1010 MCPServerConfig::Stdio(StdioServerConfig {
1011 name: "server1".to_string(),
1012 disabled: false,
1013 forbidden_tools: vec![],
1014 tool_meta: std::collections::HashMap::new(),
1015 default_tool_meta: None,
1016 vrl: None,
1017 server_parameters: StdioServerParameters {
1018 command: "echo".to_string(),
1019 args: vec![],
1020 env: std::collections::HashMap::new(),
1021 cwd: None,
1022 },
1023 }),
1024 );
1025
1026 let computer = Computer::new(
1027 "test_computer",
1028 session,
1029 Some(inputs),
1030 Some(servers),
1031 false,
1032 false,
1033 );
1034
1035 let inputs = computer.list_inputs().await.unwrap();
1037 assert_eq!(inputs.len(), 1);
1038 match &inputs[0] {
1039 MCPServerInput::PromptString(input) => {
1040 assert_eq!(input.id, "input1");
1041 assert_eq!(input.description, "Test input");
1042 }
1043 _ => panic!("Expected PromptString input"),
1044 }
1045 }
1046
1047 #[tokio::test]
1048 async fn test_input_management() {
1049 let session = SilentSession::new("test");
1050 let computer = Computer::new("test_computer", session, None, None, true, true);
1051
1052 let input = MCPServerInput::PromptString(PromptStringInput {
1054 id: "test_input".to_string(),
1055 description: "Test input".to_string(),
1056 default: Some("default".to_string()),
1057 password: Some(false),
1058 });
1059
1060 computer.add_or_update_input(input.clone()).await.unwrap();
1061
1062 let retrieved = computer.get_input("test_input").await.unwrap();
1064 assert!(retrieved.is_some());
1065
1066 let inputs = computer.list_inputs().await.unwrap();
1068 assert_eq!(inputs.len(), 1);
1069
1070 let updated_input = MCPServerInput::PromptString(PromptStringInput {
1072 id: "test_input".to_string(),
1073 description: "Updated description".to_string(),
1074 default: Some("new_default".to_string()),
1075 password: Some(true),
1076 });
1077 computer.add_or_update_input(updated_input).await.unwrap();
1078
1079 let retrieved = computer.get_input("test_input").await.unwrap().unwrap();
1080 match retrieved {
1081 MCPServerInput::PromptString(input) => {
1082 assert_eq!(input.description, "Updated description");
1083 assert_eq!(input.default, Some("new_default".to_string()));
1084 assert_eq!(input.password, Some(true));
1085 }
1086 _ => panic!("Expected PromptString input"),
1087 }
1088
1089 let removed = computer.remove_input("test_input").await.unwrap();
1091 assert!(removed);
1092
1093 let retrieved = computer.get_input("test_input").await.unwrap();
1094 assert!(retrieved.is_none());
1095
1096 let removed = computer.remove_input("non_existent").await.unwrap();
1098 assert!(!removed);
1099 }
1100
1101 #[tokio::test]
1102 async fn test_multiple_input_types() {
1103 let session = SilentSession::new("test");
1104 let computer = Computer::new("test_computer", session, None, None, true, true);
1105
1106 let prompt_input = MCPServerInput::PromptString(PromptStringInput {
1108 id: "prompt".to_string(),
1109 description: "Prompt input".to_string(),
1110 default: None,
1111 password: Some(false),
1112 });
1113
1114 let pick_input = MCPServerInput::PickString(PickStringInput {
1115 id: "pick".to_string(),
1116 description: "Pick input".to_string(),
1117 options: vec!["option1".to_string(), "option2".to_string()],
1118 default: Some("option1".to_string()),
1119 });
1120
1121 let command_input = MCPServerInput::Command(CommandInput {
1122 id: "command".to_string(),
1123 description: "Command input".to_string(),
1124 command: "ls".to_string(),
1125 args: None,
1126 });
1127
1128 computer.add_or_update_input(prompt_input).await.unwrap();
1129 computer.add_or_update_input(pick_input).await.unwrap();
1130 computer.add_or_update_input(command_input).await.unwrap();
1131
1132 let inputs = computer.list_inputs().await.unwrap();
1133 assert_eq!(inputs.len(), 3);
1134
1135 let input_types: std::collections::HashSet<_> = inputs
1137 .iter()
1138 .map(|input| match input {
1139 MCPServerInput::PromptString(_) => "prompt",
1140 MCPServerInput::PickString(_) => "pick",
1141 MCPServerInput::Command(_) => "command",
1142 })
1143 .collect();
1144
1145 assert!(input_types.contains("prompt"));
1146 assert!(input_types.contains("pick"));
1147 assert!(input_types.contains("command"));
1148 }
1149
1150 #[tokio::test]
1151 async fn test_server_management() {
1152 let session = SilentSession::new("test");
1153 let computer = Computer::new("test_computer", session, None, None, true, true);
1154
1155 let server_config = MCPServerConfig::Stdio(StdioServerConfig {
1157 name: "test_server".to_string(),
1158 disabled: false,
1159 forbidden_tools: vec![],
1160 tool_meta: std::collections::HashMap::new(),
1161 default_tool_meta: None,
1162 vrl: None,
1163 server_parameters: StdioServerParameters {
1164 command: "echo".to_string(),
1165 args: vec!["hello".to_string()],
1166 env: std::collections::HashMap::new(),
1167 cwd: None,
1168 },
1169 });
1170
1171 computer
1172 .add_or_update_server(server_config.clone())
1173 .await
1174 .unwrap();
1175
1176 let updated_config = MCPServerConfig::Stdio(StdioServerConfig {
1179 name: "test_server".to_string(),
1180 disabled: true, forbidden_tools: vec!["tool1".to_string()],
1182 tool_meta: std::collections::HashMap::new(),
1183 default_tool_meta: None,
1184 vrl: None,
1185 server_parameters: StdioServerParameters {
1186 command: "echo".to_string(),
1187 args: vec!["updated".to_string()],
1188 env: std::collections::HashMap::new(),
1189 cwd: None,
1190 },
1191 });
1192
1193 computer.add_or_update_server(updated_config).await.unwrap();
1194
1195 computer.remove_server("test_server").await.unwrap();
1197 }
1198
1199 #[tokio::test]
1200 async fn test_session_trait() {
1201 let session = SilentSession::new("test_session");
1203 assert_eq!(session.session_id(), "test_session");
1204
1205 let prompt_input = MCPServerInput::PromptString(PromptStringInput {
1207 id: "test".to_string(),
1208 description: "Test".to_string(),
1209 default: Some("default_value".to_string()),
1210 password: Some(false),
1211 });
1212
1213 let result = session.resolve_input(&prompt_input).await.unwrap();
1214 assert_eq!(
1215 result,
1216 serde_json::Value::String("default_value".to_string())
1217 );
1218
1219 let no_default_input = MCPServerInput::PromptString(PromptStringInput {
1221 id: "test2".to_string(),
1222 description: "Test2".to_string(),
1223 default: None,
1224 password: Some(false),
1225 });
1226
1227 let result = session.resolve_input(&no_default_input).await.unwrap();
1228 assert_eq!(result, serde_json::Value::String("".to_string()));
1229
1230 let pick_input = MCPServerInput::PickString(PickStringInput {
1232 id: "pick".to_string(),
1233 description: "Pick".to_string(),
1234 options: vec!["opt1".to_string(), "opt2".to_string()],
1235 default: Some("opt2".to_string()),
1236 });
1237
1238 let result = session.resolve_input(&pick_input).await.unwrap();
1239 assert_eq!(result, serde_json::Value::String("opt2".to_string()));
1240
1241 let command_input = MCPServerInput::Command(CommandInput {
1243 id: "cmd".to_string(),
1244 description: "Command".to_string(),
1245 command: "echo hello world".to_string(),
1246 args: None,
1247 });
1248
1249 let result = session.resolve_input(&command_input).await.unwrap();
1250 assert_eq!(result, serde_json::Value::String("hello world".to_string()));
1251 }
1252
1253 #[tokio::test]
1254 async fn test_cache_operations() {
1255 let session = SilentSession::new("test");
1256 let computer = Computer::new("test_computer", session, None, None, true, true);
1257
1258 let input = MCPServerInput::PromptString(PromptStringInput {
1260 id: "test_input".to_string(),
1261 description: "Test input".to_string(),
1262 default: Some("default".to_string()),
1263 password: Some(false),
1264 });
1265 computer.add_or_update_input(input).await.unwrap();
1266
1267 let test_value = serde_json::Value::String("cached_value".to_string());
1269 let set_result = computer
1270 .set_input_value("test_input", test_value.clone())
1271 .await
1272 .unwrap();
1273 assert!(set_result);
1274
1275 let retrieved = computer.get_input_value("test_input").await.unwrap();
1276 assert_eq!(retrieved, Some(test_value));
1277
1278 let invalid_result = computer
1280 .set_input_value(
1281 "nonexistent",
1282 serde_json::Value::String("value".to_string()),
1283 )
1284 .await
1285 .unwrap();
1286 assert!(!invalid_result);
1287
1288 let not_found = computer.get_input_value("nonexistent").await.unwrap();
1290 assert!(not_found.is_none());
1291 }
1292
1293 #[tokio::test]
1294 async fn test_cache_remove_and_clear() {
1295 let session = SilentSession::new("test");
1296 let computer = Computer::new("test_computer", session, None, None, true, true);
1297
1298 let input1 = MCPServerInput::PromptString(PromptStringInput {
1300 id: "input1".to_string(),
1301 description: "Input 1".to_string(),
1302 default: None,
1303 password: Some(false),
1304 });
1305 let input2 = MCPServerInput::PromptString(PromptStringInput {
1306 id: "input2".to_string(),
1307 description: "Input 2".to_string(),
1308 default: None,
1309 password: Some(false),
1310 });
1311 computer.add_or_update_input(input1).await.unwrap();
1312 computer.add_or_update_input(input2).await.unwrap();
1313
1314 computer
1316 .set_input_value("input1", serde_json::Value::String("value1".to_string()))
1317 .await
1318 .unwrap();
1319 computer
1320 .set_input_value("input2", serde_json::Value::String("value2".to_string()))
1321 .await
1322 .unwrap();
1323
1324 let removed = computer.remove_input_value("input1").await.unwrap();
1326 assert!(removed);
1327
1328 let retrieved = computer.get_input_value("input1").await.unwrap();
1329 assert!(retrieved.is_none());
1330
1331 let still_exists = computer.get_input_value("input2").await.unwrap();
1332 assert!(still_exists.is_some());
1333
1334 computer.clear_input_values(None).await.unwrap();
1336 let cleared1 = computer.get_input_value("input1").await.unwrap();
1337 let cleared2 = computer.get_input_value("input2").await.unwrap();
1338 assert!(cleared1.is_none());
1339 assert!(cleared2.is_none());
1340 }
1341
1342 #[tokio::test]
1343 async fn test_cache_list_values() {
1344 let session = SilentSession::new("test");
1345 let computer = Computer::new("test_computer", session, None, None, true, true);
1346
1347 let input1 = MCPServerInput::PromptString(PromptStringInput {
1349 id: "input1".to_string(),
1350 description: "Input 1".to_string(),
1351 default: None,
1352 password: Some(false),
1353 });
1354 let input2 = MCPServerInput::PromptString(PromptStringInput {
1355 id: "input2".to_string(),
1356 description: "Input 2".to_string(),
1357 default: None,
1358 password: Some(false),
1359 });
1360 computer.add_or_update_input(input1).await.unwrap();
1361 computer.add_or_update_input(input2).await.unwrap();
1362
1363 computer
1365 .set_input_value(
1366 "input1",
1367 serde_json::Value::String("string_value".to_string()),
1368 )
1369 .await
1370 .unwrap();
1371 computer
1372 .set_input_value(
1373 "input2",
1374 serde_json::Value::Number(serde_json::Number::from(42)),
1375 )
1376 .await
1377 .unwrap();
1378
1379 let values = computer.list_input_values().await.unwrap();
1381 assert_eq!(values.len(), 2);
1382 assert_eq!(
1383 values.get("input1"),
1384 Some(&serde_json::Value::String("string_value".to_string()))
1385 );
1386 assert_eq!(
1387 values.get("input2"),
1388 Some(&serde_json::Value::Number(serde_json::Number::from(42)))
1389 );
1390 }
1391
1392 #[tokio::test]
1393 async fn test_cache_clear_on_input_update() {
1394 let session = SilentSession::new("test");
1395 let computer = Computer::new("test_computer", session, None, None, true, true);
1396
1397 let input = MCPServerInput::PromptString(PromptStringInput {
1399 id: "test_input".to_string(),
1400 description: "Test input".to_string(),
1401 default: None,
1402 password: Some(false),
1403 });
1404 computer.add_or_update_input(input).await.unwrap();
1405
1406 computer
1408 .set_input_value(
1409 "test_input",
1410 serde_json::Value::String("cached".to_string()),
1411 )
1412 .await
1413 .unwrap();
1414 assert!(computer
1415 .get_input_value("test_input")
1416 .await
1417 .unwrap()
1418 .is_some());
1419
1420 let updated_input = MCPServerInput::PromptString(PromptStringInput {
1422 id: "test_input".to_string(),
1423 description: "Updated input".to_string(),
1424 default: Some("new_default".to_string()),
1425 password: Some(true),
1426 });
1427 computer.add_or_update_input(updated_input).await.unwrap();
1428
1429 assert!(computer
1431 .get_input_value("test_input")
1432 .await
1433 .unwrap()
1434 .is_none());
1435 }
1436
1437 #[tokio::test]
1438 async fn test_cache_clear_on_input_remove() {
1439 let session = SilentSession::new("test");
1440 let computer = Computer::new("test_computer", session, None, None, true, true);
1441
1442 let input = MCPServerInput::PromptString(PromptStringInput {
1444 id: "test_input".to_string(),
1445 description: "Test input".to_string(),
1446 default: None,
1447 password: Some(false),
1448 });
1449 computer.add_or_update_input(input).await.unwrap();
1450
1451 computer
1453 .set_input_value(
1454 "test_input",
1455 serde_json::Value::String("cached".to_string()),
1456 )
1457 .await
1458 .unwrap();
1459 assert!(computer
1460 .get_input_value("test_input")
1461 .await
1462 .unwrap()
1463 .is_some());
1464
1465 let removed = computer.remove_input("test_input").await.unwrap();
1467 assert!(removed);
1468
1469 assert!(computer
1471 .get_input_value("test_input")
1472 .await
1473 .unwrap()
1474 .is_none());
1475 }
1476
1477 #[tokio::test]
1478 async fn test_tool_call_history() {
1479 let session = SilentSession::new("test");
1480 let computer = Computer::new("test_computer", session, None, None, true, true);
1481
1482 let history = computer.get_tool_history().await.unwrap();
1484 assert!(history.is_empty());
1485
1486 }
1489
1490 #[tokio::test]
1491 async fn test_confirmation_callback() {
1492 let session = SilentSession::new("test");
1493 let computer = Computer::new("test_computer", session, None, None, true, true);
1494
1495 let callback_called = Arc::new(Mutex::new(false));
1497 let callback_called_clone = callback_called.clone();
1498
1499 let _computer = computer.with_confirm_callback(move |_req_id, _server, _tool, _params| {
1500 let rt = tokio::runtime::Handle::current();
1503 rt.block_on(async {
1504 let mut called = callback_called_clone.lock().await;
1505 *called = true;
1506 });
1507 true });
1509
1510 }
1513
1514 #[tokio::test]
1515 async fn test_computer_shutdown() {
1516 let session = SilentSession::new("test");
1517 let computer = Computer::new("test_computer", session, None, None, true, true);
1518
1519 computer.shutdown().await.unwrap();
1521
1522 computer.boot_up().await.unwrap();
1524 computer.shutdown().await.unwrap();
1525 }
1526
1527 #[tokio::test]
1528 async fn test_config_render() {
1529 let session = SilentSession::new("test");
1530
1531 let mut inputs = HashMap::new();
1533 inputs.insert(
1534 "api_key".to_string(),
1535 MCPServerInput::PromptString(PromptStringInput {
1536 id: "api_key".to_string(),
1537 description: "API Key".to_string(),
1538 default: Some("test-api-key-12345".to_string()),
1539 password: Some(true),
1540 }),
1541 );
1542 inputs.insert(
1543 "server_url".to_string(),
1544 MCPServerInput::PromptString(PromptStringInput {
1545 id: "server_url".to_string(),
1546 description: "Server URL".to_string(),
1547 default: Some("https://api.example.com".to_string()),
1548 password: Some(false),
1549 }),
1550 );
1551
1552 let computer = Computer::new("test_computer", session, Some(inputs), None, true, true);
1553
1554 let server_config = MCPServerConfig::Stdio(StdioServerConfig {
1556 name: "test_server".to_string(),
1557 disabled: false,
1558 forbidden_tools: vec![],
1559 tool_meta: std::collections::HashMap::new(),
1560 default_tool_meta: None,
1561 vrl: None,
1562 server_parameters: StdioServerParameters {
1563 command: "echo".to_string(),
1564 args: vec!["${input:api_key}".to_string()],
1565 env: {
1566 let mut env = std::collections::HashMap::new();
1567 env.insert("API_URL".to_string(), "${input:server_url}".to_string());
1568 env
1569 },
1570 cwd: None,
1571 },
1572 });
1573
1574 let rendered = computer.render_server_config(&server_config).await.unwrap();
1576
1577 match rendered {
1579 MCPServerConfig::Stdio(config) => {
1580 assert_eq!(config.server_parameters.args[0], "test-api-key-12345");
1581 assert_eq!(
1582 config.server_parameters.env.get("API_URL"),
1583 Some(&"https://api.example.com".to_string())
1584 );
1585 }
1586 _ => panic!("Expected Stdio config"),
1587 }
1588 }
1589
1590 #[tokio::test]
1591 async fn test_config_render_missing_input() {
1592 let session = SilentSession::new("test");
1593 let computer = Computer::new("test_computer", session, None, None, true, true);
1594
1595 let server_config = MCPServerConfig::Stdio(StdioServerConfig {
1597 name: "test_server".to_string(),
1598 disabled: false,
1599 forbidden_tools: vec![],
1600 tool_meta: std::collections::HashMap::new(),
1601 default_tool_meta: None,
1602 vrl: None,
1603 server_parameters: StdioServerParameters {
1604 command: "echo".to_string(),
1605 args: vec!["${input:missing_input}".to_string()],
1606 env: std::collections::HashMap::new(),
1607 cwd: None,
1608 },
1609 });
1610
1611 let rendered = computer.render_server_config(&server_config).await.unwrap();
1613
1614 match rendered {
1615 MCPServerConfig::Stdio(config) => {
1616 assert_eq!(config.server_parameters.args[0], "${input:missing_input}");
1618 }
1619 _ => panic!("Expected Stdio config"),
1620 }
1621 }
1622}