1use std::marker::PhantomData;
31use std::time::{Duration, Instant};
32
33use crate::capability::{
34 ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult, ServerCapabilities,
35 ServerInfo,
36};
37use crate::error::McpError;
38use crate::protocol::RequestId;
39
/// Type-state marker: the connection is not established.
///
/// This is the starting state, and also the state reached after disconnecting,
/// aborting initialization, or closing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Disconnected;
43
/// Type-state marker: the connection is established, but initialization has not started.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Connected;
47
/// Type-state marker: initialization has begun.
///
/// Client info and capabilities are recorded; server info and capabilities are not yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Initializing;
51
/// Type-state marker: initialization is complete.
///
/// Client and server info and capabilities are recorded, and request ids can be issued.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Ready;
55
/// Type-state marker: shutdown has begun, and the only remaining transition is closing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Closing;
59
60#[derive(Debug)]
62pub struct ConnectionInner {
63 pub id: String,
65 pub connected_at: Option<Instant>,
67 pub last_activity: Option<Instant>,
69 pub request_counter: u64,
71 pub client_info: Option<ClientInfo>,
73 pub server_info: Option<ServerInfo>,
75 pub client_capabilities: Option<ClientCapabilities>,
77 pub server_capabilities: Option<ServerCapabilities>,
79}
80
81impl ConnectionInner {
82 fn new() -> Self {
84 Self {
85 id: uuid::Uuid::new_v4().to_string(),
86 connected_at: None,
87 last_activity: None,
88 request_counter: 0,
89 client_info: None,
90 server_info: None,
91 client_capabilities: None,
92 server_capabilities: None,
93 }
94 }
95
96 fn next_request_id(&mut self) -> RequestId {
98 self.request_counter += 1;
99 RequestId::Number(self.request_counter)
100 }
101
102 fn touch(&mut self) {
104 self.last_activity = Some(Instant::now());
105 }
106}
107
108impl Default for ConnectionInner {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114#[derive(Debug)]
119pub struct Connection<S> {
120 inner: ConnectionInner,
121 _state: PhantomData<S>,
122}
123
124impl Connection<Disconnected> {
125 #[must_use]
127 pub fn new() -> Self {
128 Self {
129 inner: ConnectionInner::new(),
130 _state: PhantomData,
131 }
132 }
133
134 #[must_use]
136 pub fn id(&self) -> &str {
137 &self.inner.id
138 }
139
140 #[must_use]
145 pub fn connect(mut self) -> Connection<Connected> {
146 self.inner.connected_at = Some(Instant::now());
147 self.inner.touch();
148 Connection {
149 inner: self.inner,
150 _state: PhantomData,
151 }
152 }
153}
154
155impl Default for Connection<Disconnected> {
156 fn default() -> Self {
157 Self::new()
158 }
159}
160
161impl Connection<Connected> {
162 #[must_use]
164 pub fn id(&self) -> &str {
165 &self.inner.id
166 }
167
168 #[must_use]
170 pub fn connected_at(&self) -> Option<Instant> {
171 self.inner.connected_at
172 }
173
174 #[must_use]
176 pub fn uptime(&self) -> Duration {
177 self.inner
178 .connected_at
179 .map(|t| t.elapsed())
180 .unwrap_or_default()
181 }
182
183 pub fn initialize(
188 mut self,
189 client_info: ClientInfo,
190 client_capabilities: ClientCapabilities,
191 ) -> (Connection<Initializing>, InitializeRequest) {
192 self.inner.client_info = Some(client_info.clone());
193 self.inner.client_capabilities = Some(client_capabilities.clone());
194 self.inner.touch();
195
196 let request = InitializeRequest::new(client_info, client_capabilities);
197
198 (
199 Connection {
200 inner: self.inner,
201 _state: PhantomData,
202 },
203 request,
204 )
205 }
206
207 #[must_use]
209 pub fn disconnect(self) -> Connection<Disconnected> {
210 Connection {
211 inner: ConnectionInner::new(),
212 _state: PhantomData,
213 }
214 }
215}
216
217impl Connection<Initializing> {
218 #[must_use]
220 pub fn id(&self) -> &str {
221 &self.inner.id
222 }
223
224 #[must_use]
226 pub fn client_info(&self) -> Option<&ClientInfo> {
227 self.inner.client_info.as_ref()
228 }
229
230 #[must_use]
232 pub fn client_capabilities(&self) -> Option<&ClientCapabilities> {
233 self.inner.client_capabilities.as_ref()
234 }
235
236 pub fn complete(
241 mut self,
242 server_info: ServerInfo,
243 server_capabilities: ServerCapabilities,
244 ) -> Connection<Ready> {
245 self.inner.server_info = Some(server_info);
246 self.inner.server_capabilities = Some(server_capabilities);
247 self.inner.touch();
248
249 Connection {
250 inner: self.inner,
251 _state: PhantomData,
252 }
253 }
254
255 #[must_use]
257 pub fn abort(self) -> Connection<Disconnected> {
258 Connection {
259 inner: ConnectionInner::new(),
260 _state: PhantomData,
261 }
262 }
263}
264
265impl Connection<Ready> {
266 #[must_use]
268 pub fn id(&self) -> &str {
269 &self.inner.id
270 }
271
272 #[must_use]
274 pub fn connected_at(&self) -> Option<Instant> {
275 self.inner.connected_at
276 }
277
278 #[must_use]
280 pub fn uptime(&self) -> Duration {
281 self.inner
282 .connected_at
283 .map(|t| t.elapsed())
284 .unwrap_or_default()
285 }
286
287 #[must_use]
289 pub fn last_activity(&self) -> Option<Instant> {
290 self.inner.last_activity
291 }
292
293 #[must_use]
301 pub fn client_info(&self) -> &ClientInfo {
302 self.inner.client_info.as_ref().expect("client_info should be set in Ready state")
303 }
304
305 #[must_use]
309 pub fn try_client_info(&self) -> Option<&ClientInfo> {
310 self.inner.client_info.as_ref()
311 }
312
313 #[must_use]
321 pub fn server_info(&self) -> &ServerInfo {
322 self.inner.server_info.as_ref().expect("server_info should be set in Ready state")
323 }
324
325 #[must_use]
329 pub fn try_server_info(&self) -> Option<&ServerInfo> {
330 self.inner.server_info.as_ref()
331 }
332
333 #[must_use]
341 pub fn client_capabilities(&self) -> &ClientCapabilities {
342 self.inner.client_capabilities.as_ref().expect("client_capabilities should be set in Ready state")
343 }
344
345 #[must_use]
349 pub fn try_client_capabilities(&self) -> Option<&ClientCapabilities> {
350 self.inner.client_capabilities.as_ref()
351 }
352
353 #[must_use]
361 pub fn server_capabilities(&self) -> &ServerCapabilities {
362 self.inner.server_capabilities.as_ref().expect("server_capabilities should be set in Ready state")
363 }
364
365 #[must_use]
369 pub fn try_server_capabilities(&self) -> Option<&ServerCapabilities> {
370 self.inner.server_capabilities.as_ref()
371 }
372
373 pub fn next_request_id(&mut self) -> RequestId {
375 self.inner.next_request_id()
376 }
377
378 pub fn touch(&mut self) {
380 self.inner.touch();
381 }
382
383 #[must_use]
385 pub fn is_idle(&self, timeout: Duration) -> bool {
386 self.inner
387 .last_activity
388 .map(|t| t.elapsed() > timeout)
389 .unwrap_or(false)
390 }
391
392 #[must_use]
394 pub fn shutdown(self) -> Connection<Closing> {
395 Connection {
396 inner: self.inner,
397 _state: PhantomData,
398 }
399 }
400}
401
402impl Connection<Closing> {
403 #[must_use]
405 pub fn id(&self) -> &str {
406 &self.inner.id
407 }
408
409 #[must_use]
411 pub fn close(self) -> Connection<Disconnected> {
412 Connection {
413 inner: ConnectionInner::new(),
414 _state: PhantomData,
415 }
416 }
417}
418
419pub struct InitializeResultBuilder {
421 server_info: ServerInfo,
422 capabilities: ServerCapabilities,
423 instructions: Option<String>,
424}
425
426impl InitializeResultBuilder {
427 #[must_use]
429 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
430 Self {
431 server_info: ServerInfo::new(name, version),
432 capabilities: ServerCapabilities::new(),
433 instructions: None,
434 }
435 }
436
437 #[must_use]
439 pub fn capabilities(mut self, caps: ServerCapabilities) -> Self {
440 self.capabilities = caps;
441 self
442 }
443
444 #[must_use]
446 pub fn with_tools(mut self) -> Self {
447 self.capabilities = self.capabilities.with_tools();
448 self
449 }
450
451 #[must_use]
453 pub fn with_resources(mut self) -> Self {
454 self.capabilities = self.capabilities.with_resources();
455 self
456 }
457
458 #[must_use]
460 pub fn with_prompts(mut self) -> Self {
461 self.capabilities = self.capabilities.with_prompts();
462 self
463 }
464
465 #[must_use]
467 pub fn with_tasks(mut self) -> Self {
468 self.capabilities = self.capabilities.with_tasks();
469 self
470 }
471
472 #[must_use]
474 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
475 self.instructions = Some(instructions.into());
476 self
477 }
478
479 #[must_use]
481 pub fn build(self) -> InitializeResult {
482 let mut result = InitializeResult::new(self.server_info, self.capabilities);
483 if let Some(instructions) = self.instructions {
484 result = result.instructions(instructions);
485 }
486 result
487 }
488}
489
490pub fn validate_initialization(
492 _client_caps: &ClientCapabilities,
493 _server_caps: &ServerCapabilities,
494) -> Result<(), McpError> {
495 Ok(())
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a connection through every state and checks two things: its id survives
    /// each transition, and closing hands back a connection with a fresh id.
    #[test]
    fn test_connection_lifecycle() {
        let conn: Connection<Disconnected> = Connection::new();
        assert!(!conn.id().is_empty());
        let original_id = conn.id().to_owned();

        let conn: Connection<Connected> = conn.connect();
        assert!(conn.connected_at().is_some());
        assert_eq!(conn.id(), original_id);

        let client = ClientInfo::new("test", "1.0.0");
        let caps = ClientCapabilities::new();
        // Only the state transition is under test; the underscore keeps the unused
        // request from tripping the `unused_variables` lint.
        let (conn, _request): (Connection<Initializing>, _) = conn.initialize(client, caps);
        assert!(conn.client_info().is_some());
        assert_eq!(conn.id(), original_id);

        let server = ServerInfo::new("server", "1.0.0");
        let server_caps = ServerCapabilities::new().with_tools();
        let mut conn: Connection<Ready> = conn.complete(server, server_caps);
        assert!(conn.server_capabilities().has_tools());
        assert_eq!(conn.id(), original_id);

        let id1 = conn.next_request_id();
        let id2 = conn.next_request_id();
        assert_ne!(id1, id2);

        let conn: Connection<Closing> = conn.shutdown();
        assert_eq!(conn.id(), original_id);

        // Closing discards the old state, so the new connection gets a new id.
        let conn: Connection<Disconnected> = conn.close();
        assert_ne!(conn.id(), original_id);
    }

    #[test]
    fn test_uptime() {
        let conn = Connection::new().connect();
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(conn.uptime() >= std::time::Duration::from_millis(10));
    }

    #[test]
    fn test_idle_detection() {
        let client = ClientInfo::new("test", "1.0.0");
        let server = ServerInfo::new("server", "1.0.0");

        let (conn, _) = Connection::new()
            .connect()
            .initialize(client, ClientCapabilities::new());

        let mut conn = conn.complete(server, ServerCapabilities::new());

        assert!(!conn.is_idle(Duration::from_secs(1)));

        // `touch` must never move the last-activity timestamp backwards.
        let before = conn.last_activity();
        conn.touch();
        assert!(conn.last_activity() >= before);
        assert!(!conn.is_idle(Duration::from_secs(1)));
    }

    #[test]
    fn test_initialize_result_builder() {
        let result = InitializeResultBuilder::new("my-server", "1.0.0")
            .with_tools()
            .with_resources()
            .instructions("Use this server to access tools and resources")
            .build();

        assert_eq!(result.server_info.name, "my-server");
        assert!(result.capabilities.has_tools());
        assert!(result.capabilities.has_resources());
        assert!(result.instructions.is_some());
    }

    #[test]
    fn test_initialize_result_builder_without_instructions() {
        let result = InitializeResultBuilder::new("bare-server", "0.1.0").build();
        assert_eq!(result.server_info.name, "bare-server");
        assert!(result.instructions.is_none());
    }

    #[test]
    fn test_abort_initialization() {
        let client = ClientInfo::new("test", "1.0.0");
        let (conn, _) = Connection::new()
            .connect()
            .initialize(client, ClientCapabilities::new());
        let original_id = conn.id().to_owned();

        // Aborting discards all recorded state, including the id.
        let conn: Connection<Disconnected> = conn.abort();
        assert_ne!(conn.id(), original_id);
    }

    #[test]
    fn test_disconnect_from_connected() {
        let conn = Connection::new().connect();
        let original_id = conn.id().to_owned();

        // Disconnecting discards all recorded state, including the id.
        let conn: Connection<Disconnected> = conn.disconnect();
        assert_ne!(conn.id(), original_id);
    }

    #[test]
    fn test_fallible_accessors() {
        let client = ClientInfo::new("test-client", "1.0.0");
        let server = ServerInfo::new("test-server", "2.0.0");
        let client_caps = ClientCapabilities::new();
        let server_caps = ServerCapabilities::new().with_tools();

        let (conn, _) = Connection::new()
            .connect()
            .initialize(client.clone(), client_caps.clone());

        let conn = conn.complete(server.clone(), server_caps.clone());

        assert!(conn.try_client_info().is_some());
        assert!(conn.try_server_info().is_some());
        assert!(conn.try_client_capabilities().is_some());
        assert!(conn.try_server_capabilities().is_some());

        assert_eq!(conn.try_client_info().unwrap().name, "test-client");
        assert_eq!(conn.try_server_info().unwrap().name, "test-server");
        assert!(conn.try_server_capabilities().unwrap().has_tools());
    }

    #[test]
    fn test_validate_initialization_accepts_defaults() {
        let client_caps = ClientCapabilities::new();
        let server_caps = ServerCapabilities::new();
        assert!(validate_initialization(&client_caps, &server_caps).is_ok());
    }
}