Skip to main content

ares/mcp/
client.rs

1use serde::Deserialize;
2use serde_json::Value;
3
4#[derive(Debug, Clone, Deserialize)]
5pub struct McpServerConfig {
6  pub name: String,
7  pub enabled: bool,
8  pub command: Option<String>,
9  pub args: Option<Vec<String>>,
10  pub timeout_secs: Option<u64>,
11  pub endpoint: Option<String>,
12  pub transport: Option<String>,
13  pub api_key: Option<String>,
14}
15
16pub struct McpClient {
17  config: McpServerConfig,
18  http: reqwest::Client,
19}
20
21#[derive(Debug, thiserror::Error)]
22pub enum McpError {
23  #[error("HTTP request failed: {0}")]
24  Http(#[from] reqwest::Error),
25  #[error("MCP server returned error: {0}")]
26  ServerError(String),
27  #[error("Tool not found: {0}")]
28  ToolNotFound(String),
29  #[error("Deserialize error: {0}")]
30  Deserialize(#[from] serde_json::Error),
31  #[error("MCP server is disabled")]
32  ServerDisabled,
33  #[error("No endpoint configured")]
34  NoEndpoint,
35}
36
37impl McpClient {
38  pub fn new(config: McpServerConfig) -> Self {
39    let http = reqwest::Client::builder()
40      .timeout(std::time::Duration::from_secs(
41        config.timeout_secs.unwrap_or(30),
42      ))
43      .build()
44      .expect("Failed to create HTTP client");
45
46    Self { config, http }
47  }
48
49  pub fn is_enabled(&self) -> bool {
50    self.config.enabled
51  }
52
53  pub fn name(&self) -> &str {
54    &self.config.name
55  }
56
57  pub async fn get_context(&self, path: &str) -> Result<Value, McpError> {
58    let base_url = self.get_base_url()?;
59    let url = format!("{}/api/v1/context?path={}", base_url, path);
60
61    let mut request = self.http.get(&url);
62    if let Some(ref key) = self.config.api_key {
63      request = request.header("Authorization", format!("Bearer {}", key));
64    }
65
66    let response = request.send().await?;
67    self.handle_response(response).await
68  }
69
70  pub async fn write_context(&self, path: &str, value: &str) -> Result<Value, McpError> {
71    let base_url = self.get_base_url()?;
72    let url = format!("{}/api/v1/context", base_url);
73
74    let body = serde_json::json!({
75      "path": path,
76      "value": value
77    });
78
79    let mut request = self.http.post(&url).json(&body);
80    if let Some(ref key) = self.config.api_key {
81      request = request.header("Authorization", format!("Bearer {}", key));
82    }
83
84    let response = request.send().await?;
85    self.handle_response(response).await
86  }
87
88  pub async fn search_context(
89    &self,
90    query: &str,
91    scope: Option<&str>,
92    max_results: Option<usize>,
93  ) -> Result<Value, McpError> {
94    let base_url = self.get_base_url()?;
95    let url = format!("{}/api/v1/context/search", base_url);
96
97    let mut body = serde_json::json!({
98      "query": query
99    });
100    if let Some(s) = scope {
101      body["scope"] = serde_json::json!(s);
102    }
103    if let Some(m) = max_results {
104      body["max_results"] = serde_json::json!(m);
105    }
106
107    let mut request = self.http.post(&url).json(&body);
108    if let Some(ref key) = self.config.api_key {
109      request = request.header("Authorization", format!("Bearer {}", key));
110    }
111
112    let response = request.send().await?;
113    self.handle_response(response).await
114  }
115
116  pub async fn get_completeness(&self, scope: Option<&str>) -> Result<Value, McpError> {
117    let base_url = self.get_base_url()?;
118    let scope_part = scope.unwrap_or("*");
119    let url = format!("{}/api/v1/completeness/{}", base_url, scope_part);
120
121    let mut request = self.http.get(&url);
122    if let Some(ref key) = self.config.api_key {
123      request = request.header("Authorization", format!("Bearer {}", key));
124    }
125
126    let response = request.send().await?;
127    self.handle_response(response).await
128  }
129
130  pub async fn get_gaps(
131    &self,
132    status: Option<&str>,
133    category: Option<&str>,
134  ) -> Result<Value, McpError> {
135    let base_url = self.get_base_url()?;
136    let url = format!("{}/api/v1/gaps", base_url);
137
138    let mut request = self.http.get(&url);
139    if let Some(ref key) = self.config.api_key {
140      request = request.header("Authorization", format!("Bearer {}", key));
141    }
142
143    if let Some(s) = status {
144      request = request.query(&[("status", s)]);
145    }
146    if let Some(c) = category {
147      request = request.query(&[("category", c)]);
148    }
149
150    let response = request.send().await?;
151    self.handle_response(response).await
152  }
153
154  pub async fn detect_gaps(&self, category: Option<&str>) -> Result<Value, McpError> {
155    let base_url = self.get_base_url()?;
156    let url = format!("{}/api/v1/gaps/detect", base_url);
157
158    let body = if let Some(cat) = category {
159      serde_json::json!({ "category": cat })
160    } else {
161      serde_json::json!({})
162    };
163
164    let mut request = self.http.post(&url).json(&body);
165    if let Some(ref key) = self.config.api_key {
166      request = request.header("Authorization", format!("Bearer {}", key));
167    }
168
169    let response = request.send().await?;
170    self.handle_response(response).await
171  }
172
173  fn get_base_url(&self) -> Result<String, McpError> {
174    self.config.endpoint.clone().ok_or(McpError::NoEndpoint)
175  }
176
177  async fn handle_response(&self, response: reqwest::Response) -> Result<Value, McpError> {
178    if !response.status().is_success() {
179      let status = response.status();
180      let text = response.text().await.unwrap_or_default();
181      return Err(McpError::ServerError(format!("HTTP {}: {}", status, text)));
182    }
183
184    let result: Value = response.json().await?;
185    Ok(result)
186  }
187}