1use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
14#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
15pub enum CompressionMode {
16 None,
18 Light,
20 Standard,
22 Quantum,
24 QuantumSemantic,
26 Auto,
28}
29
30impl CompressionMode {
31 pub fn from_env() -> Self {
33 std::env::var("ST_COMPRESSION")
34 .ok()
35 .and_then(|s| match s.to_lowercase().as_str() {
36 "none" | "raw" => Some(Self::None),
37 "light" => Some(Self::Light),
38 "standard" | "normal" => Some(Self::Standard),
39 "quantum" => Some(Self::Quantum),
40 "quantum-semantic" | "max" => Some(Self::QuantumSemantic),
41 "auto" => Some(Self::Auto),
42 _ => None,
43 })
44 .unwrap_or(Self::Auto)
45 }
46
47 pub fn auto_select(file_count: usize) -> Self {
49 match file_count {
50 0..=50 => Self::None, 51..=200 => Self::Light, 201..=500 => Self::Standard, 501..=1000 => Self::Quantum, _ => Self::QuantumSemantic, }
56 }
57
58 pub fn to_output_mode(&self) -> &'static str {
60 match self {
61 Self::None => "classic",
62 Self::Light => "ai",
63 Self::Standard => "summary-ai",
64 Self::Quantum => "quantum",
65 Self::QuantumSemantic => "quantum-semantic",
66 Self::Auto => "auto",
67 }
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct SessionPreferences {
74 pub format: CompressionMode,
76 pub depth: DepthMode,
78 pub tools: ToolAdvertisement,
80 pub project_path: Option<PathBuf>,
82}
83
84impl Default for SessionPreferences {
85 fn default() -> Self {
86 Self {
87 format: CompressionMode::Auto,
88 depth: DepthMode::Adaptive,
89 tools: ToolAdvertisement::Lazy,
90 project_path: None,
91 }
92 }
93}
94
95#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
97#[serde(rename_all = "snake_case")]
98pub enum DepthMode {
99 Shallow,
101 Standard,
103 Deep,
105 Adaptive,
107}
108
109impl DepthMode {
110 pub fn to_depth(&self, dir_count: usize) -> usize {
111 match self {
112 Self::Shallow => 2,
113 Self::Standard => 4,
114 Self::Deep => 10,
115 Self::Adaptive => {
116 match dir_count {
118 0..=10 => 10, 11..=50 => 5, 51..=100 => 4, _ => 3, }
123 }
124 }
125 }
126}
127
128#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
130#[serde(rename_all = "snake_case")]
131pub enum ToolAdvertisement {
132 All,
134 Lazy,
136 ContextAware,
138 Minimal,
140}
141
142#[derive(Debug, Clone)]
144pub struct McpSession {
145 pub id: String,
147 pub preferences: SessionPreferences,
149 pub project_path: PathBuf,
151 pub negotiated: bool,
153 pub started_at: std::time::SystemTime,
155}
156
157impl Default for McpSession {
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
163impl McpSession {
164 pub fn new() -> Self {
166 Self {
167 id: format!("STX-{:x}", rand::random::<u32>()),
168 preferences: SessionPreferences::default(),
169 project_path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
170 negotiated: false,
171 started_at: std::time::SystemTime::now(),
172 }
173 }
174
175 pub fn from_context(initial_path: Option<PathBuf>) -> Self {
177 let mut session = Self::new();
178
179 if let Some(path) = initial_path {
181 session.project_path = path;
182 } else if let Ok(cwd) = std::env::current_dir() {
183 if cwd.join("Cargo.toml").exists()
185 || cwd.join("package.json").exists()
186 || cwd.join("pyproject.toml").exists()
187 || cwd.join(".git").exists()
188 {
189 session.project_path = cwd;
190 }
191 }
192
193 session.preferences.format = CompressionMode::from_env();
195
196 session
197 }
198
199 pub fn negotiate(&mut self, client_prefs: Option<SessionPreferences>) -> NegotiationResponse {
201 if let Some(prefs) = client_prefs {
202 self.preferences = prefs;
203 self.negotiated = true;
204
205 NegotiationResponse {
206 session_id: self.id.clone(),
207 accepted: true,
208 format: self.preferences.format,
209 project_path: self.project_path.clone(),
210 tools_available: self.get_available_tools(),
211 }
212 } else {
213 NegotiationResponse {
215 session_id: self.id.clone(),
216 accepted: false,
217 format: self.preferences.format,
218 project_path: self.project_path.clone(),
219 tools_available: vec!["overview".to_string(), "find".to_string()],
220 }
221 }
222 }
223
224 pub fn get_available_tools(&self) -> Vec<String> {
226 match self.preferences.tools {
227 ToolAdvertisement::All => {
228 vec![
230 "overview",
231 "find",
232 "search",
233 "analyze",
234 "edit",
235 "history",
236 "context",
237 "memory",
238 "compare",
239 "feedback",
240 "server_info",
241 "verify_permissions",
242 "sse",
243 ]
245 .into_iter()
246 .map(String::from)
247 .collect()
248 }
249 ToolAdvertisement::Lazy => {
250 vec!["overview", "find", "search"]
252 .into_iter()
253 .map(String::from)
254 .collect()
255 }
256 ToolAdvertisement::ContextAware => {
257 let mut tools = vec!["overview", "find", "search"];
259
260 if self.project_path.join("Cargo.toml").exists() {
262 tools.push("analyze"); }
264 if self.project_path.join(".git").exists() {
265 tools.push("history"); }
267
268 tools.into_iter().map(String::from).collect()
269 }
270 ToolAdvertisement::Minimal => {
271 vec!["overview"].into_iter().map(String::from).collect()
273 }
274 }
275 }
276
277 pub fn apply_context(&self, tool_name: &str, params: &mut serde_json::Value) {
279 if let Some(obj) = params.as_object_mut() {
281 if !obj.contains_key("path") {
282 obj.insert(
283 "path".to_string(),
284 serde_json::Value::String(self.project_path.to_string_lossy().to_string()),
285 );
286 }
287
288 if tool_name == "overview" && !obj.contains_key("mode") {
290 obj.insert(
291 "mode".to_string(),
292 serde_json::Value::String(self.preferences.format.to_output_mode().to_string()),
293 );
294 }
295 }
296 }
297}
298
299#[derive(Debug, Serialize, Deserialize)]
301pub struct NegotiationResponse {
302 pub session_id: String,
303 pub accepted: bool,
304 pub format: CompressionMode,
305 pub project_path: PathBuf,
306 pub tools_available: Vec<String>,
307}
308
309#[derive(Debug, Serialize, Deserialize)]
311pub struct NegotiationRequest {
312 pub session_prefs: Option<SessionPreferences>,
313 pub capabilities: Vec<String>,
314}
315
316pub struct SessionManager {
318 sessions: Arc<RwLock<std::collections::HashMap<String, McpSession>>>,
319}
320
321impl Default for SessionManager {
322 fn default() -> Self {
323 Self::new()
324 }
325}
326
327impl SessionManager {
328 pub fn new() -> Self {
329 Self {
330 sessions: Arc::new(RwLock::new(std::collections::HashMap::new())),
331 }
332 }
333
334 pub async fn get_or_create(&self, session_id: Option<String>) -> McpSession {
336 let mut sessions = self.sessions.write().await;
337
338 if let Some(id) = session_id {
339 if let Some(session) = sessions.get(&id) {
340 return session.clone();
341 }
342 }
343
344 let session = McpSession::from_context(None);
346 sessions.insert(session.id.clone(), session.clone());
347 session
348 }
349
350 pub async fn update(&self, session: McpSession) {
352 let mut sessions = self.sessions.write().await;
353 sessions.insert(session.id.clone(), session);
354 }
355
356 pub async fn cleanup(&self) {
358 let mut sessions = self.sessions.write().await;
359 let now = std::time::SystemTime::now();
360
361 sessions.retain(|_, session| {
362 if let Ok(duration) = now.duration_since(session.started_at) {
363 duration.as_secs() < 3600 } else {
365 true
366 }
367 });
368 }
369}
370
371use rand;
373
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compression_auto_select() {
        assert_eq!(CompressionMode::auto_select(10), CompressionMode::None);
        assert_eq!(CompressionMode::auto_select(100), CompressionMode::Light);
        assert_eq!(CompressionMode::auto_select(300), CompressionMode::Standard);
        assert_eq!(CompressionMode::auto_select(700), CompressionMode::Quantum);
        assert_eq!(
            CompressionMode::auto_select(2000),
            CompressionMode::QuantumSemantic
        );
    }

    #[test]
    fn test_compression_auto_select_boundaries() {
        // Each arm's inclusive upper bound and the first value past it.
        assert_eq!(CompressionMode::auto_select(0), CompressionMode::None);
        assert_eq!(CompressionMode::auto_select(50), CompressionMode::None);
        assert_eq!(CompressionMode::auto_select(51), CompressionMode::Light);
        assert_eq!(CompressionMode::auto_select(200), CompressionMode::Light);
        assert_eq!(CompressionMode::auto_select(201), CompressionMode::Standard);
        assert_eq!(CompressionMode::auto_select(500), CompressionMode::Standard);
        assert_eq!(CompressionMode::auto_select(501), CompressionMode::Quantum);
        assert_eq!(CompressionMode::auto_select(1000), CompressionMode::Quantum);
        assert_eq!(
            CompressionMode::auto_select(1001),
            CompressionMode::QuantumSemantic
        );
    }

    #[test]
    fn test_output_mode_names() {
        let expected = [
            (CompressionMode::None, "classic"),
            (CompressionMode::Light, "ai"),
            (CompressionMode::Standard, "summary-ai"),
            (CompressionMode::Quantum, "quantum"),
            (CompressionMode::QuantumSemantic, "quantum-semantic"),
            (CompressionMode::Auto, "auto"),
        ];
        for &(mode, name) in expected.iter() {
            assert_eq!(mode.to_output_mode(), name);
        }
    }

    #[test]
    fn test_depth_adaptive() {
        let depth = DepthMode::Adaptive;
        assert_eq!(depth.to_depth(5), 10);
        assert_eq!(depth.to_depth(30), 5);
        assert_eq!(depth.to_depth(200), 3);
    }

    #[test]
    fn test_depth_boundaries_and_fixed_modes() {
        let adaptive = DepthMode::Adaptive;
        assert_eq!(adaptive.to_depth(0), 10);
        assert_eq!(adaptive.to_depth(10), 10);
        assert_eq!(adaptive.to_depth(11), 5);
        assert_eq!(adaptive.to_depth(50), 5);
        assert_eq!(adaptive.to_depth(51), 4);
        assert_eq!(adaptive.to_depth(100), 4);
        assert_eq!(adaptive.to_depth(101), 3);

        // Fixed modes ignore the directory count.
        assert_eq!(DepthMode::Shallow.to_depth(1_000), 2);
        assert_eq!(DepthMode::Standard.to_depth(0), 4);
        assert_eq!(DepthMode::Deep.to_depth(1_000), 10);
    }

    #[test]
    fn test_tool_advertisement_lists() {
        let mut session = McpSession::new();

        session.preferences.tools = ToolAdvertisement::Minimal;
        assert_eq!(session.get_available_tools(), vec!["overview"]);

        session.preferences.tools = ToolAdvertisement::Lazy;
        assert_eq!(
            session.get_available_tools(),
            vec!["overview", "find", "search"]
        );

        session.preferences.tools = ToolAdvertisement::All;
        let all = session.get_available_tools();
        assert_eq!(all.len(), 13);
        assert!(all.iter().any(|t| t == "sse"));
    }

    #[test]
    fn test_apply_context_fills_only_missing_keys() {
        let mut session = McpSession::new();
        session.project_path = std::path::PathBuf::from("/tmp/project");
        session.preferences.format = CompressionMode::Quantum;

        let mut params = serde_json::json!({});
        session.apply_context("overview", &mut params);
        assert_eq!(params["path"], "/tmp/project");
        assert_eq!(params["mode"], "quantum");

        let mut params = serde_json::json!({"path": "/elsewhere", "mode": "classic"});
        session.apply_context("overview", &mut params);
        assert_eq!(params["path"], "/elsewhere");
        assert_eq!(params["mode"], "classic");

        // `mode` is only defaulted for the overview tool.
        let mut params = serde_json::json!({});
        session.apply_context("find", &mut params);
        assert_eq!(params["path"], "/tmp/project");
        assert!(params.get("mode").is_none());
    }
}