systemprompt-mcp 0.1.2

Core MCP (Model Context Protocol) functionality for systemprompt.io OS
Documentation
use serde::{Deserialize, Serialize};

/// Source lists for a Content-Security-Policy, one field per directive.
///
/// Each field holds the source expressions for the directive with the
/// matching name (`default_src` for `default-src`, and so on). A directive
/// whose list is empty is left out of the header rendered by
/// `to_header_value`. The derived `Default` yields a policy in which every
/// list is empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CspPolicy {
    /// Sources for the `default-src` directive.
    pub default_src: Vec<String>,
    /// Sources for the `script-src` directive.
    pub script_src: Vec<String>,
    /// Sources for the `style-src` directive.
    pub style_src: Vec<String>,
    /// Sources for the `img-src` directive.
    pub img_src: Vec<String>,
    /// Sources for the `connect-src` directive.
    pub connect_src: Vec<String>,
    /// Sources for the `font-src` directive.
    pub font_src: Vec<String>,
    /// Sources for the `frame-src` directive.
    pub frame_src: Vec<String>,
    /// Sources for the `base-uri` directive.
    pub base_uri: Vec<String>,
}

impl CspPolicy {
    /// Builds the baseline restrictive policy.
    ///
    /// Every directive is limited to `'self'`, with these exceptions:
    /// `script-src` and `style-src` additionally allow `'unsafe-inline'`,
    /// `img-src` additionally allows `data:` URIs, and `frame-src` is
    /// `'none'`.
    ///
    /// Note that, despite the name, inline scripts and styles are permitted
    /// by this preset.
    #[must_use]
    pub fn strict() -> Self {
        /// Converts a list of literal source expressions into owned strings.
        fn owned(sources: &[&str]) -> Vec<String> {
            sources.iter().map(|source| (*source).to_string()).collect()
        }

        Self {
            default_src: owned(&["'self'"]),
            script_src: owned(&["'self'", "'unsafe-inline'"]),
            style_src: owned(&["'self'", "'unsafe-inline'"]),
            img_src: owned(&["'self'", "data:"]),
            connect_src: owned(&["'self'"]),
            font_src: owned(&["'self'"]),
            frame_src: owned(&["'none'"]),
            base_uri: owned(&["'self'"]),
        }
    }

    /// Builds the [`strict`](Self::strict) policy and additionally allows
    /// each of `cdn_origins` as a script and style source.
    ///
    /// The origins are appended, in the order given, to both `script_src`
    /// and `style_src`; no other directive is touched and no de-duplication
    /// is performed.
    #[must_use]
    pub fn with_cdn(cdn_origins: &[&str]) -> Self {
        let mut policy = Self::strict();
        policy
            .script_src
            .extend(cdn_origins.iter().map(|origin| (*origin).to_string()));
        policy
            .style_src
            .extend(cdn_origins.iter().map(|origin| (*origin).to_string()));
        policy
    }

    /// Renders the policy as a `Content-Security-Policy` header value.
    ///
    /// Directives are emitted in a fixed order (`default-src`, `script-src`,
    /// `style-src`, `img-src`, `connect-src`, `font-src`, `frame-src`,
    /// `base-uri`), each as `name source source ...`, joined with `"; "`.
    /// A directive whose source list is empty is skipped entirely, so a
    /// policy with no sources at all renders as the empty string.
    ///
    /// Source strings are written verbatim; they are not validated or
    /// escaped.
    #[must_use]
    pub fn to_header_value(&self) -> String {
        // Single table of (directive name, sources) so that the emission
        // order and the skip-if-empty rule live in exactly one place.
        let directives: [(&str, &[String]); 8] = [
            ("default-src", self.default_src.as_slice()),
            ("script-src", self.script_src.as_slice()),
            ("style-src", self.style_src.as_slice()),
            ("img-src", self.img_src.as_slice()),
            ("connect-src", self.connect_src.as_slice()),
            ("font-src", self.font_src.as_slice()),
            ("frame-src", self.frame_src.as_slice()),
            ("base-uri", self.base_uri.as_slice()),
        ];

        directives
            .iter()
            .filter(|(_, sources)| !sources.is_empty())
            .map(|(name, sources)| format!("{} {}", name, sources.join(" ")))
            .collect::<Vec<_>>()
            .join("; ")
    }
}

/// Chainable builder that assembles a [`CspPolicy`].
///
/// The derived `Default` wraps `CspPolicy::default()`, i.e. a policy whose
/// source lists are all empty.
#[derive(Debug, Default)]
pub struct CspBuilder {
    /// The policy being assembled; handed back to the caller by `build`.
    policy: CspPolicy,
}

impl CspBuilder {
    /// Creates a builder via `Self::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a builder pre-populated with [`CspPolicy::strict`].
    #[must_use]
    pub fn strict() -> Self {
        Self {
            policy: CspPolicy::strict(),
        }
    }

    /// Replaces the whole `default-src` source list with `sources`.
    #[must_use]
    pub fn default_src(mut self, sources: Vec<String>) -> Self {
        self.policy.default_src = sources;
        self
    }

    /// Replaces the whole `script-src` source list with `sources`.
    #[must_use]
    pub fn script_src(mut self, sources: Vec<String>) -> Self {
        self.policy.script_src = sources;
        self
    }

    /// Appends a single `source` to the existing `script-src` list.
    #[must_use]
    pub fn add_script_src(mut self, source: &str) -> Self {
        self.policy.script_src.push(source.to_string());
        self
    }

    /// Replaces the whole `style-src` source list with `sources`.
    #[must_use]
    pub fn style_src(mut self, sources: Vec<String>) -> Self {
        self.policy.style_src = sources;
        self
    }

    /// Appends a single `source` to the existing `style-src` list.
    #[must_use]
    pub fn add_style_src(mut self, source: &str) -> Self {
        self.policy.style_src.push(source.to_string());
        self
    }

    /// Replaces the whole `img-src` source list with `sources`.
    #[must_use]
    pub fn img_src(mut self, sources: Vec<String>) -> Self {
        self.policy.img_src = sources;
        self
    }

    /// Replaces the whole `connect-src` source list with `sources`.
    #[must_use]
    pub fn connect_src(mut self, sources: Vec<String>) -> Self {
        self.policy.connect_src = sources;
        self
    }

    /// Appends a single `source` to the existing `connect-src` list.
    #[must_use]
    pub fn add_connect_src(mut self, source: &str) -> Self {
        self.policy.connect_src.push(source.to_string());
        self
    }

    /// Replaces the whole `font-src` source list with `sources`.
    #[must_use]
    pub fn font_src(mut self, sources: Vec<String>) -> Self {
        self.policy.font_src = sources;
        self
    }

    /// Replaces the whole `frame-src` source list with `sources`.
    #[must_use]
    pub fn frame_src(mut self, sources: Vec<String>) -> Self {
        self.policy.frame_src = sources;
        self
    }

    /// Replaces the whole `base-uri` source list with `sources`.
    #[must_use]
    pub fn base_uri(mut self, sources: Vec<String>) -> Self {
        self.policy.base_uri = sources;
        self
    }

    /// Consumes the builder and returns the assembled policy.
    #[must_use]
    pub fn build(self) -> CspPolicy {
        self.policy
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The strict preset renders its `'self'` default, the inline-script
    /// allowance, and the frame ban.
    #[test]
    fn test_strict_policy() {
        let rendered = CspPolicy::strict().to_header_value();

        assert!(rendered.contains("default-src 'self'"));
        assert!(rendered.contains("script-src 'self' 'unsafe-inline'"));
        assert!(rendered.contains("frame-src 'none'"));
    }

    /// An origin passed to `with_cdn` shows up in the rendered header.
    #[test]
    fn test_cdn_policy() {
        let cdn_origins = ["https://cdn.jsdelivr.net"];
        let rendered = CspPolicy::with_cdn(&cdn_origins).to_header_value();

        assert!(rendered.contains("https://cdn.jsdelivr.net"));
    }

    /// Sources appended through the builder survive into the rendered header.
    #[test]
    fn test_builder() {
        let builder = CspBuilder::strict()
            .add_script_src("https://example.com")
            .add_connect_src("wss://api.example.com");

        let rendered = builder.build().to_header_value();
        assert!(rendered.contains("https://example.com"));
        assert!(rendered.contains("wss://api.example.com"));
    }
}