Skip to main content

mcp/
builder_utils.rs

//! Utilities for `MCPServerBuilder`.
//!
//! This module provides helper functionality for the MCP server builder,
//! including client IP filtering (single addresses and CIDR matching) and
//! request authentication.
5
6use sha2::{Digest, Sha256};
7use std::net::IpAddr;
8use std::str::FromStr;
9use subtle::ConstantTimeEq;
10
11/// IP filter for restricting server access
12#[derive(Debug, Clone)]
13pub struct IpFilter {
14    /// List of allowed IP addresses and CIDR blocks
15    allowed: Vec<IpFilterEntry>,
16}
17
/// A single rule in an `IpFilter` allow-list.
#[derive(Clone, Debug)]
enum IpFilterEntry {
    /// Matches exactly one address.
    Single(IpAddr),
    /// Matches every address that shares the first `prefix_len` bits of
    /// `network` (written `network/prefix_len`).
    Cidr {
        /// Base address of the block.
        network: IpAddr,
        /// Number of leading bits that must match.
        prefix_len: u8,
    },
}
25
26impl IpFilter {
27    /// Create a new empty IP filter (allows all)
28    pub fn new() -> Self {
29        Self {
30            allowed: Vec::new(),
31        }
32    }
33
34    /// Add an allowed IP address or CIDR block
35    ///
36    /// # Arguments
37    ///
38    /// * `ip_or_cidr` - Either an IP address (e.g., "127.0.0.1", "::1") or a CIDR block
39    ///   (e.g., "192.168.1.0/24", "2001:db8::/32")
40    ///
41    /// # Returns
42    ///
43    /// `Err` if the input is invalid
44    pub fn allow(&mut self, ip_or_cidr: &str) -> Result<(), String> {
45        // Try parsing as CIDR first
46        if let Some(slash_pos) = ip_or_cidr.find('/') {
47            let (network_part, prefix_part) = ip_or_cidr.split_at(slash_pos);
48            let prefix_str = &prefix_part[1..]; // Skip the '/'
49
50            let network = IpAddr::from_str(network_part)
51                .map_err(|e| format!("Invalid network address: {}", e))?;
52
53            let prefix_len: u8 = prefix_str
54                .parse()
55                .map_err(|_| format!("Invalid CIDR prefix length: {}", prefix_str))?;
56
57            // Validate prefix length based on IP type
58            let max_prefix = match network {
59                IpAddr::V4(_) => 32,
60                IpAddr::V6(_) => 128,
61            };
62
63            if prefix_len > max_prefix {
64                return Err(format!(
65                    "CIDR prefix length {} exceeds maximum {} for {:?}",
66                    prefix_len, max_prefix, network
67                ));
68            }
69
70            self.allowed.push(IpFilterEntry::Cidr {
71                network,
72                prefix_len,
73            });
74            Ok(())
75        } else {
76            // Try parsing as single IP
77            let ip =
78                IpAddr::from_str(ip_or_cidr).map_err(|e| format!("Invalid IP address: {}", e))?;
79            self.allowed.push(IpFilterEntry::Single(ip));
80            Ok(())
81        }
82    }
83
84    /// Check if an IP address is allowed
85    pub fn is_allowed(&self, ip: IpAddr) -> bool {
86        // If no restrictions, allow all
87        if self.allowed.is_empty() {
88            return true;
89        }
90
91        // Check each allowed entry
92        self.allowed.iter().any(|entry| self.matches(ip, entry))
93    }
94
95    /// Check if an IP matches a filter entry
96    fn matches(&self, ip: IpAddr, entry: &IpFilterEntry) -> bool {
97        match entry {
98            IpFilterEntry::Single(allowed_ip) => ip == *allowed_ip,
99            IpFilterEntry::Cidr {
100                network,
101                prefix_len,
102            } => self.ip_in_cidr(ip, *network, *prefix_len),
103        }
104    }
105
106    /// Check if an IP is in a CIDR block
107    fn ip_in_cidr(&self, ip: IpAddr, network: IpAddr, prefix_len: u8) -> bool {
108        match (ip, network) {
109            (IpAddr::V4(ip), IpAddr::V4(net)) => {
110                let ip_bits = u32::from(ip);
111                let net_bits = u32::from(net);
112                let mask = if prefix_len == 0 {
113                    0
114                } else {
115                    0xFFFFFFFFu32 << (32 - prefix_len)
116                };
117                (ip_bits & mask) == (net_bits & mask)
118            }
119            (IpAddr::V6(ip), IpAddr::V6(net)) => {
120                let ip_bits = u128::from(ip);
121                let net_bits = u128::from(net);
122                let mask = if prefix_len == 0 {
123                    0
124                } else {
125                    0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFu128 << (128 - prefix_len)
126                };
127                (ip_bits & mask) == (net_bits & mask)
128            }
129            _ => false, // IPv4 vs IPv6 mismatch
130        }
131    }
132}
133
134impl Default for IpFilter {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
/// Authentication configuration for MCP server
///
/// The `Debug` representation redacts secrets (the bearer token and the Basic
/// password), so logging a configuration does not leak credentials.
#[derive(Clone)]
pub enum AuthConfig {
    /// No authentication required
    None,
    /// Bearer token authentication; holds the expected token.
    Bearer(String),
    /// HTTP Basic authentication.
    ///
    /// The expected credentials are stored as plain text. Clients send them
    /// base64-encoded as `username:password`.
    Basic {
        /// Expected username.
        username: String,
        /// Expected password.
        password: String,
    },
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written by hand instead of derived so secrets never end up in logs
        // or panic messages. The username is not treated as secret.
        match self {
            AuthConfig::None => f.write_str("None"),
            AuthConfig::Bearer(_) => f.debug_tuple("Bearer").field(&"<redacted>").finish(),
            AuthConfig::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &"<redacted>")
                .finish(),
        }
    }
}
155
156impl AuthConfig {
157    /// Create bearer token authentication
158    pub fn bearer(token: impl Into<String>) -> Self {
159        Self::Bearer(token.into())
160    }
161
162    /// Create basic authentication
163    pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
164        Self::Basic {
165            username: username.into(),
166            password: password.into(),
167        }
168    }
169
170    /// Validate an Authorization header
171    ///
172    /// # Arguments
173    ///
174    /// * `header` - The Authorization header value (e.g., "Bearer token123")
175    ///
176    /// # Returns
177    ///
178    /// `true` if the header matches the configured authentication
179    pub fn validate(&self, header: &str) -> bool {
180        match self {
181            AuthConfig::None => true,
182            AuthConfig::Bearer(token) => {
183                if let Some(token_part) = header.strip_prefix("Bearer ") {
184                    // subtle::ConstantTimeEq on SHA-256 digests prevents timing oracle attacks.
185                    // The optimizer cannot short-circuit ct_eq() the way it can with `==`.
186                    let expected_hash = Sha256::digest(token.as_bytes());
187                    let provided_hash = Sha256::digest(token_part.as_bytes());
188                    expected_hash.ct_eq(&provided_hash).into()
189                } else {
190                    false
191                }
192            }
193            AuthConfig::Basic { username, password } => {
194                if let Some(creds_part) = header.strip_prefix("Basic ") {
195                    // Decode base64 and check against username:password
196                    if let Ok(decoded) = base64_decode(creds_part) {
197                        let expected = format!("{}:{}", username, password);
198                        // subtle::ConstantTimeEq prevents timing oracle on credentials.
199                        let expected_hash = Sha256::digest(expected.as_bytes());
200                        let decoded_hash = Sha256::digest(decoded.as_bytes());
201                        expected_hash.ct_eq(&decoded_hash).into()
202                    } else {
203                        false
204                    }
205                } else {
206                    false
207                }
208            }
209        }
210    }
211}
212
/// Decode a standard-alphabet (RFC 4648 §4) base64 string into UTF-8 text.
///
/// Trailing `=` padding is optional, so both `"YQ=="` and `"YQ"` decode to
/// `"a"`; at most two padding characters are accepted. Unused low bits in the
/// final symbol are ignored.
///
/// # Errors
///
/// Returns `Err` in any of these cases:
///
/// * the input contains a character outside the standard alphabet (including
///   `=` anywhere other than the trailing padding);
/// * it has more than two padding characters;
/// * it ends with a lone symbol that cannot encode a whole byte;
/// * the decoded bytes are not valid UTF-8.
fn base64_decode(s: &str) -> Result<String, String> {
    // Simple base64 decoding without external dependencies
    const BASE64_TABLE: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut table = [255u8; 256];
    for (i, &c) in BASE64_TABLE.iter().enumerate() {
        table[usize::from(c)] = i as u8; // i < 64, so this never truncates
    }

    let input = s.trim_end_matches('=');
    if s.len() - input.len() > 2 {
        return Err("Invalid base64 padding".to_string());
    }

    let bytes = input.as_bytes();
    // A final group of a single symbol carries only 6 bits, which is not
    // enough for a byte. Every group handled below therefore has 2-4 symbols.
    if bytes.len() % 4 == 1 {
        return Err("Invalid base64 length".to_string());
    }

    let mut output = Vec::with_capacity(bytes.len() / 4 * 3 + 2);
    for chunk in bytes.chunks(4) {
        let mut buf = [0u8; 4];
        for (slot, &c) in buf.iter_mut().zip(chunk) {
            // '=' is not in the alphabet. Legitimate padding was trimmed
            // above, so any '=' still present is rejected here.
            let value = table[usize::from(c)];
            if value == 255 {
                return Err("Invalid base64 character".to_string());
            }
            *slot = value;
        }

        output.push((buf[0] << 2) | (buf[1] >> 4));
        if chunk.len() > 2 {
            output.push(((buf[1] & 0x0F) << 4) | (buf[2] >> 2));
        }
        if chunk.len() > 3 {
            output.push(((buf[2] & 0x03) << 6) | buf[3]);
        }
    }

    String::from_utf8(output).map_err(|e| e.to_string())
}