1use std::marker::PhantomData;
31use std::time::{Duration, Instant};
32
33use crate::capability::{
34 ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult, ServerCapabilities,
35 ServerInfo,
36};
37use crate::error::McpError;
38use crate::protocol::RequestId;
39
/// Typestate marker: no connection is currently established.
///
/// Fresh connections start here, and `disconnect`, `abort` and `close` all return to this
/// stage. The type is zero-sized and only ever appears as the `S` in `Connection<S>`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Disconnected;
43
/// Typestate marker: the connection is established but the initialize handshake has not
/// started yet.
///
/// Produced by `Connection::<Disconnected>::connect`. Zero-sized.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Connected;
47
/// Typestate marker: the initialize request has been produced and the connection is waiting
/// for the server's info and capabilities.
///
/// Produced by `Connection::<Connected>::initialize`; left via `complete` or `abort`.
/// Zero-sized.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Initializing;
51
/// Typestate marker: the handshake finished; client and server info and capabilities are
/// all recorded.
///
/// Produced by `Connection::<Initializing>::complete`. Zero-sized.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Ready;
55
/// Typestate marker: shutdown was requested on a ready connection.
///
/// Produced by `Connection::<Ready>::shutdown`; finished with `close`. Zero-sized.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Closing;
59
60#[derive(Debug)]
62pub struct ConnectionInner {
63 pub id: String,
65 pub connected_at: Option<Instant>,
67 pub last_activity: Option<Instant>,
69 pub request_counter: u64,
71 pub client_info: Option<ClientInfo>,
73 pub server_info: Option<ServerInfo>,
75 pub client_capabilities: Option<ClientCapabilities>,
77 pub server_capabilities: Option<ServerCapabilities>,
79}
80
81impl ConnectionInner {
82 fn new() -> Self {
84 Self {
85 id: uuid::Uuid::new_v4().to_string(),
86 connected_at: None,
87 last_activity: None,
88 request_counter: 0,
89 client_info: None,
90 server_info: None,
91 client_capabilities: None,
92 server_capabilities: None,
93 }
94 }
95
96 fn next_request_id(&mut self) -> RequestId {
98 self.request_counter += 1;
99 RequestId::Number(self.request_counter)
100 }
101
102 fn touch(&mut self) {
104 self.last_activity = Some(Instant::now());
105 }
106}
107
108impl Default for ConnectionInner {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114#[derive(Debug)]
119pub struct Connection<S> {
120 inner: ConnectionInner,
121 _state: PhantomData<S>,
122}
123
124impl Connection<Disconnected> {
125 #[must_use]
127 pub fn new() -> Self {
128 Self {
129 inner: ConnectionInner::new(),
130 _state: PhantomData,
131 }
132 }
133
134 #[must_use]
136 pub fn id(&self) -> &str {
137 &self.inner.id
138 }
139
140 #[must_use]
145 pub fn connect(mut self) -> Connection<Connected> {
146 self.inner.connected_at = Some(Instant::now());
147 self.inner.touch();
148 Connection {
149 inner: self.inner,
150 _state: PhantomData,
151 }
152 }
153}
154
155impl Default for Connection<Disconnected> {
156 fn default() -> Self {
157 Self::new()
158 }
159}
160
161impl Connection<Connected> {
162 #[must_use]
164 pub fn id(&self) -> &str {
165 &self.inner.id
166 }
167
168 #[must_use]
170 pub const fn connected_at(&self) -> Option<Instant> {
171 self.inner.connected_at
172 }
173
174 #[must_use]
176 pub fn uptime(&self) -> Duration {
177 self.inner
178 .connected_at
179 .map(|t| t.elapsed())
180 .unwrap_or_default()
181 }
182
183 #[must_use]
188 pub fn initialize(
189 mut self,
190 client_info: ClientInfo,
191 client_capabilities: ClientCapabilities,
192 ) -> (Connection<Initializing>, InitializeRequest) {
193 self.inner.client_info = Some(client_info.clone());
194 self.inner.client_capabilities = Some(client_capabilities.clone());
195 self.inner.touch();
196
197 let request = InitializeRequest::new(client_info, client_capabilities);
198
199 (
200 Connection {
201 inner: self.inner,
202 _state: PhantomData,
203 },
204 request,
205 )
206 }
207
208 #[must_use]
210 pub fn disconnect(self) -> Connection<Disconnected> {
211 Connection {
212 inner: ConnectionInner::new(),
213 _state: PhantomData,
214 }
215 }
216}
217
218impl Connection<Initializing> {
219 #[must_use]
221 pub fn id(&self) -> &str {
222 &self.inner.id
223 }
224
225 #[must_use]
227 pub const fn client_info(&self) -> Option<&ClientInfo> {
228 self.inner.client_info.as_ref()
229 }
230
231 #[must_use]
233 pub const fn client_capabilities(&self) -> Option<&ClientCapabilities> {
234 self.inner.client_capabilities.as_ref()
235 }
236
237 #[must_use]
242 pub fn complete(
243 mut self,
244 server_info: ServerInfo,
245 server_capabilities: ServerCapabilities,
246 ) -> Connection<Ready> {
247 self.inner.server_info = Some(server_info);
248 self.inner.server_capabilities = Some(server_capabilities);
249 self.inner.touch();
250
251 Connection {
252 inner: self.inner,
253 _state: PhantomData,
254 }
255 }
256
257 #[must_use]
259 pub fn abort(self) -> Connection<Disconnected> {
260 Connection {
261 inner: ConnectionInner::new(),
262 _state: PhantomData,
263 }
264 }
265}
266
267impl Connection<Ready> {
268 #[must_use]
270 pub fn id(&self) -> &str {
271 &self.inner.id
272 }
273
274 #[must_use]
276 pub const fn connected_at(&self) -> Option<Instant> {
277 self.inner.connected_at
278 }
279
280 #[must_use]
282 pub fn uptime(&self) -> Duration {
283 self.inner
284 .connected_at
285 .map(|t| t.elapsed())
286 .unwrap_or_default()
287 }
288
289 #[must_use]
291 pub const fn last_activity(&self) -> Option<Instant> {
292 self.inner.last_activity
293 }
294
295 #[must_use]
303 pub fn client_info(&self) -> &ClientInfo {
304 self.inner
305 .client_info
306 .as_ref()
307 .expect("client_info should be set in Ready state")
308 }
309
310 #[must_use]
314 pub const fn try_client_info(&self) -> Option<&ClientInfo> {
315 self.inner.client_info.as_ref()
316 }
317
318 #[must_use]
326 pub fn server_info(&self) -> &ServerInfo {
327 self.inner
328 .server_info
329 .as_ref()
330 .expect("server_info should be set in Ready state")
331 }
332
333 #[must_use]
337 pub const fn try_server_info(&self) -> Option<&ServerInfo> {
338 self.inner.server_info.as_ref()
339 }
340
341 #[must_use]
349 pub fn client_capabilities(&self) -> &ClientCapabilities {
350 self.inner
351 .client_capabilities
352 .as_ref()
353 .expect("client_capabilities should be set in Ready state")
354 }
355
356 #[must_use]
360 pub const fn try_client_capabilities(&self) -> Option<&ClientCapabilities> {
361 self.inner.client_capabilities.as_ref()
362 }
363
364 #[must_use]
372 pub fn server_capabilities(&self) -> &ServerCapabilities {
373 self.inner
374 .server_capabilities
375 .as_ref()
376 .expect("server_capabilities should be set in Ready state")
377 }
378
379 #[must_use]
383 pub const fn try_server_capabilities(&self) -> Option<&ServerCapabilities> {
384 self.inner.server_capabilities.as_ref()
385 }
386
387 pub fn next_request_id(&mut self) -> RequestId {
389 self.inner.next_request_id()
390 }
391
392 pub fn touch(&mut self) {
394 self.inner.touch();
395 }
396
397 #[must_use]
399 pub fn is_idle(&self, timeout: Duration) -> bool {
400 self.inner
401 .last_activity
402 .is_some_and(|t| t.elapsed() > timeout)
403 }
404
405 #[must_use]
407 pub fn shutdown(self) -> Connection<Closing> {
408 Connection {
409 inner: self.inner,
410 _state: PhantomData,
411 }
412 }
413}
414
415impl Connection<Closing> {
416 #[must_use]
418 pub fn id(&self) -> &str {
419 &self.inner.id
420 }
421
422 #[must_use]
424 pub fn close(self) -> Connection<Disconnected> {
425 Connection {
426 inner: ConnectionInner::new(),
427 _state: PhantomData,
428 }
429 }
430}
431
432pub struct InitializeResultBuilder {
434 server_info: ServerInfo,
435 capabilities: ServerCapabilities,
436 instructions: Option<String>,
437}
438
439impl InitializeResultBuilder {
440 #[must_use]
442 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
443 Self {
444 server_info: ServerInfo::new(name, version),
445 capabilities: ServerCapabilities::new(),
446 instructions: None,
447 }
448 }
449
450 #[must_use]
452 pub fn capabilities(mut self, caps: ServerCapabilities) -> Self {
453 self.capabilities = caps;
454 self
455 }
456
457 #[must_use]
459 pub fn with_tools(mut self) -> Self {
460 self.capabilities = self.capabilities.with_tools();
461 self
462 }
463
464 #[must_use]
466 pub fn with_resources(mut self) -> Self {
467 self.capabilities = self.capabilities.with_resources();
468 self
469 }
470
471 #[must_use]
473 pub fn with_prompts(mut self) -> Self {
474 self.capabilities = self.capabilities.with_prompts();
475 self
476 }
477
478 #[must_use]
480 pub fn with_tasks(mut self) -> Self {
481 self.capabilities = self.capabilities.with_tasks();
482 self
483 }
484
485 #[must_use]
487 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
488 self.instructions = Some(instructions.into());
489 self
490 }
491
492 #[must_use]
494 pub fn build(self) -> InitializeResult {
495 let mut result = InitializeResult::new(self.server_info, self.capabilities);
496 if let Some(instructions) = self.instructions {
497 result = result.instructions(instructions);
498 }
499 result
500 }
501}
502
503pub const fn validate_initialization(
505 _client_caps: &ClientCapabilities,
506 _server_caps: &ServerCapabilities,
507) -> Result<(), McpError> {
508 Ok(())
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a fresh connection through the whole handshake to the `Ready` stage.
    fn ready_connection(server_caps: ServerCapabilities) -> Connection<Ready> {
        let (conn, _request) = Connection::new()
            .connect()
            .initialize(ClientInfo::new("test", "1.0.0"), ClientCapabilities::new());
        conn.complete(ServerInfo::new("server", "1.0.0"), server_caps)
    }

    #[test]
    fn test_connection_lifecycle() {
        let conn: Connection<Disconnected> = Connection::new();
        assert!(!conn.id().is_empty());

        let conn: Connection<Connected> = conn.connect();
        assert!(conn.connected_at().is_some());

        let client = ClientInfo::new("test", "1.0.0");
        let caps = ClientCapabilities::new();
        let (conn, _request): (Connection<Initializing>, _) = conn.initialize(client, caps);
        assert!(conn.client_info().is_some());

        let server = ServerInfo::new("server", "1.0.0");
        let server_caps = ServerCapabilities::new().with_tools();
        let mut conn: Connection<Ready> = conn.complete(server, server_caps);
        assert!(conn.server_capabilities().has_tools());

        let id1 = conn.next_request_id();
        let id2 = conn.next_request_id();
        assert_ne!(id1, id2);

        let conn: Connection<Closing> = conn.shutdown();
        let _conn: Connection<Disconnected> = conn.close();
    }

    #[test]
    fn test_id_preserved_through_handshake() {
        let conn = Connection::new();
        let id = conn.id().to_owned();

        let conn = conn.connect();
        assert_eq!(conn.id(), id);

        let (conn, _) =
            conn.initialize(ClientInfo::new("test", "1.0.0"), ClientCapabilities::new());
        assert_eq!(conn.id(), id);

        let conn = conn.complete(ServerInfo::new("server", "1.0.0"), ServerCapabilities::new());
        assert_eq!(conn.id(), id);

        let conn = conn.shutdown();
        assert_eq!(conn.id(), id);
    }

    #[test]
    fn test_reset_transitions_issue_fresh_id() {
        let connected = Connection::new().connect();
        let old = connected.id().to_owned();
        assert_ne!(connected.disconnect().id(), old);

        let (initializing, _) = Connection::new()
            .connect()
            .initialize(ClientInfo::new("test", "1.0.0"), ClientCapabilities::new());
        let old = initializing.id().to_owned();
        assert_ne!(initializing.abort().id(), old);

        let closing = ready_connection(ServerCapabilities::new()).shutdown();
        let old = closing.id().to_owned();
        assert_ne!(closing.close().id(), old);
    }

    #[test]
    fn test_request_ids_start_at_one_and_increase() {
        let mut conn = ready_connection(ServerCapabilities::new());
        assert_eq!(conn.next_request_id(), RequestId::Number(1));
        assert_eq!(conn.next_request_id(), RequestId::Number(2));
        assert_eq!(conn.next_request_id(), RequestId::Number(3));
    }

    #[test]
    fn test_uptime() {
        let conn = Connection::new().connect();
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(conn.uptime() >= std::time::Duration::from_millis(10));
    }

    #[test]
    fn test_idle_detection() {
        let client = ClientInfo::new("test", "1.0.0");
        let server = ServerInfo::new("server", "1.0.0");

        let (conn, _) = Connection::new()
            .connect()
            .initialize(client, ClientCapabilities::new());

        let conn = conn.complete(server, ServerCapabilities::new());

        assert!(!conn.is_idle(Duration::from_secs(1)));
    }

    #[test]
    fn test_idle_after_inactivity_and_touch_resets() {
        let mut conn = ready_connection(ServerCapabilities::new());
        std::thread::sleep(Duration::from_millis(20));
        assert!(conn.is_idle(Duration::from_millis(5)));

        conn.touch();
        assert!(!conn.is_idle(Duration::from_secs(1)));
    }

    #[test]
    fn test_initialize_result_builder() {
        let result = InitializeResultBuilder::new("my-server", "1.0.0")
            .with_tools()
            .with_resources()
            .instructions("Use this server to access tools and resources")
            .build();

        assert_eq!(result.server_info.name, "my-server");
        assert!(result.capabilities.has_tools());
        assert!(result.capabilities.has_resources());
        assert!(result.instructions.is_some());
    }

    #[test]
    fn test_initialize_result_builder_defaults_and_override() {
        let bare = InitializeResultBuilder::new("bare", "0.1.0").build();
        assert_eq!(bare.server_info.name, "bare");
        assert!(bare.instructions.is_none());

        let overridden = InitializeResultBuilder::new("custom", "0.1.0")
            .capabilities(ServerCapabilities::new().with_tools())
            .build();
        assert!(overridden.capabilities.has_tools());
    }

    #[test]
    fn test_validate_initialization_accepts_all() {
        assert!(
            validate_initialization(&ClientCapabilities::new(), &ServerCapabilities::new())
                .is_ok()
        );
        assert!(validate_initialization(
            &ClientCapabilities::new(),
            &ServerCapabilities::new().with_tools()
        )
        .is_ok());
    }

    #[test]
    fn test_abort_initialization() {
        let client = ClientInfo::new("test", "1.0.0");
        let (conn, _) = Connection::new()
            .connect()
            .initialize(client, ClientCapabilities::new());

        let _conn: Connection<Disconnected> = conn.abort();
    }

    #[test]
    fn test_disconnect_from_connected() {
        let conn = Connection::new().connect();
        let _conn: Connection<Disconnected> = conn.disconnect();
    }

    #[test]
    fn test_fallible_accessors() {
        let client = ClientInfo::new("test-client", "1.0.0");
        let server = ServerInfo::new("test-server", "2.0.0");
        let client_caps = ClientCapabilities::new();
        let server_caps = ServerCapabilities::new().with_tools();

        let (conn, _) = Connection::new().connect().initialize(client, client_caps);

        let conn = conn.complete(server, server_caps);

        assert!(conn.try_client_info().is_some());
        assert!(conn.try_server_info().is_some());
        assert!(conn.try_client_capabilities().is_some());
        assert!(conn.try_server_capabilities().is_some());

        assert_eq!(conn.try_client_info().unwrap().name, "test-client");
        assert_eq!(conn.try_server_info().unwrap().name, "test-server");
        assert!(conn.try_server_capabilities().unwrap().has_tools());
    }
}