Skip to main content

feagi_io/protocol_implementations/websocket/
shared.rs

//! Shared utilities for WebSocket implementations.
2
3use crate::FeagiNetworkError;
4use serde::{Deserialize, Serialize};
5
6/// URL endpoint struct for WebSocket endpoints with validation.
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct WebSocketUrl {
9    url: String,
10}
11
12impl WebSocketUrl {
13    /// Creates a new WebSocketUrl after validating the format.
14    ///
15    /// The URL will be normalized to include the `ws://` or `wss://` scheme if not present.
16    ///
17    /// # Arguments
18    ///
19    /// * `url` - The WebSocket URL (e.g., "ws://localhost:8080", "wss://example.com/path", "localhost:8080").
20    ///
21    /// # Errors
22    ///
23    /// Returns an error if the URL format is invalid.
24    pub fn new(url: &str) -> Result<Self, FeagiNetworkError> {
25        let normalized = normalize_ws_url(url);
26        validate_ws_url(&normalized)?;
27        Ok(WebSocketUrl { url: normalized })
28    }
29
30    /// Returns the URL as a string slice.
31    #[allow(dead_code)]
32    pub fn as_str(&self) -> &str {
33        &self.url
34    }
35
36    /// Extracts host:port from the WebSocket URL for TCP connection.
37    ///
38    /// Strips the `ws://` or `wss://` scheme and any path component,
39    /// returning just the `host:port` portion suitable for TCP connection.
40    /// If no port is specified, defaults to port 80 for ws:// or 443 for wss://.
41    pub fn host_port(&self) -> String {
42        let is_secure = self.url.starts_with("wss://");
43
44        // Remove ws:// or wss:// prefix
45        let without_scheme = self
46            .url
47            .strip_prefix("ws://")
48            .or_else(|| self.url.strip_prefix("wss://"))
49            .unwrap_or(&self.url);
50
51        // Remove any path component
52        let host_port = without_scheme.split('/').next().unwrap_or(without_scheme);
53
54        // Add default port if not specified
55        if host_port.contains(':') {
56            host_port.to_string()
57        } else {
58            // Default WebSocket ports: 80 for ws://, 443 for wss://
59            let default_port = if is_secure { 443 } else { 80 };
60            format!("{}:{}", host_port, default_port)
61        }
62    }
63
64    /// Returns whether this is a secure WebSocket URL (wss://).
65    #[allow(dead_code)]
66    pub fn is_secure(&self) -> bool {
67        self.url.starts_with("wss://")
68    }
69}
70
71impl std::fmt::Display for WebSocketUrl {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        write!(f, "{}", self.url)
74    }
75}
76
/// Normalizes a host string to a WebSocket URL.
///
/// Leading and trailing whitespace is trimmed first. Then the input is handled
/// according to its scheme:
/// - **`ws` or `wss` scheme, in any letter case:** the scheme is lowercased and
///   the rest is kept as-is.
/// - **Any other scheme (e.g. `http://`):** the input is returned unchanged, so
///   that validation rejects it instead of it silently becoming
///   `ws://http://...`.
/// - **No scheme:** `ws://` is prepended. A `://` inside a path or query string
///   does not count as a scheme.
///
/// # Examples
/// ```ignore
/// assert_eq!(normalize_ws_url("localhost:8080"), "ws://localhost:8080");
/// assert_eq!(normalize_ws_url("ws://localhost:8080"), "ws://localhost:8080");
/// assert_eq!(normalize_ws_url("wss://secure.example.com"), "wss://secure.example.com");
/// assert_eq!(normalize_ws_url("WSS://secure.example.com"), "wss://secure.example.com");
/// assert_eq!(normalize_ws_url("http://localhost:8080"), "http://localhost:8080");
/// ```
fn normalize_ws_url(url: &str) -> String {
    /// Returns `true` if `s` matches the RFC 3986 scheme grammar
    /// `ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`.
    fn is_uri_scheme(s: &str) -> bool {
        let mut chars = s.chars();
        matches!(chars.next(), Some(c) if c.is_ascii_alphabetic())
            && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    }

    let url = url.trim();
    match url.split_once("://") {
        Some((scheme, rest)) if is_uri_scheme(scheme) => {
            // Schemes are case-insensitive; store the canonical lowercase form.
            let scheme = scheme.to_ascii_lowercase();
            if scheme == "ws" || scheme == "wss" {
                format!("{}://{}", scheme, rest)
            } else {
                // Foreign scheme: leave it untouched so validation reports it.
                url.to_string()
            }
        }
        _ => format!("ws://{}", url),
    }
}
95
96/// Validates a WebSocket URL format.
97///
98/// Valid schemes: ws://, wss://
99fn validate_ws_url(url: &str) -> Result<(), FeagiNetworkError> {
100    // Check for valid WebSocket scheme prefixes
101    const VALID_PREFIXES: [&str; 2] = ["ws://", "wss://"];
102
103    if !VALID_PREFIXES.iter().any(|prefix| url.starts_with(prefix)) {
104        return Err(FeagiNetworkError::InvalidSocketProperties(format!(
105            "Invalid WebSocket URL '{}': must start with one of {:?}",
106            url, VALID_PREFIXES
107        )));
108    }
109
110    // Extract the part after the scheme
111    let addr_part = url
112        .strip_prefix("wss://")
113        .or_else(|| url.strip_prefix("ws://"))
114        .ok_or_else(|| {
115            FeagiNetworkError::InvalidSocketProperties(format!(
116                "Invalid WebSocket URL '{}': expected {:?}",
117                url, VALID_PREFIXES
118            ))
119        })?;
120
121    if addr_part.is_empty() {
122        return Err(FeagiNetworkError::InvalidSocketProperties(format!(
123            "Invalid WebSocket URL '{}': empty address after scheme",
124            url
125        )));
126    }
127
128    // Extract host:port (before any path)
129    let host_port = addr_part.split('/').next().unwrap_or(addr_part);
130
131    if host_port.is_empty() {
132        return Err(FeagiNetworkError::InvalidSocketProperties(format!(
133            "Invalid WebSocket URL '{}': empty host",
134            url
135        )));
136    }
137
138    Ok(())
139}
140
141/// Validates a bind address format (host:port).
142///
143/// # Arguments
144///
145/// * `bind_address` - The address to bind to (e.g., "127.0.0.1:8080", "0.0.0.0:8080").
146///
147/// # Errors
148///
149/// Returns an error if the address format is invalid.
150#[allow(dead_code)]
151pub fn validate_bind_address(bind_address: &str) -> Result<(), FeagiNetworkError> {
152    if bind_address.is_empty() {
153        return Err(FeagiNetworkError::InvalidSocketProperties(
154            "Invalid bind address: empty string".to_string(),
155        ));
156    }
157
158    // Check for scheme prefixes that shouldn't be in bind addresses
159    if bind_address.starts_with("ws://")
160        || bind_address.starts_with("wss://")
161        || bind_address.starts_with("http://")
162        || bind_address.starts_with("https://")
163    {
164        return Err(FeagiNetworkError::InvalidSocketProperties(format!(
165            "Invalid bind address '{}': should be host:port without scheme",
166            bind_address
167        )));
168    }
169
170    // Should contain host:port
171    if !bind_address.contains(':') {
172        return Err(FeagiNetworkError::InvalidSocketProperties(format!(
173            "Invalid bind address '{}': missing port (expected host:port)",
174            bind_address
175        )));
176    }
177
178    Ok(())
179}