sekuire 0.1.0

The official SDK for the Sekuire Agent Identity Protocol
Documentation
use blake3;
use ed25519_dalek::{Signature, VerifyingKey};
use hex;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;

/// Errors produced while fetching, verifying, or enforcing a policy.
#[derive(Debug, Error)]
pub enum PolicyError {
    /// Error reported by `reqwest`: a failed request, a non-success HTTP status, or a
    /// response body that could not be decoded.
    #[error("network error: {0}")]
    Network(#[from] reqwest::Error),
    /// JSON (de)serialisation error from `serde_json`.
    #[error("serde error: {0}")]
    Serde(#[from] serde_json::Error),
    /// The BLAKE3 hash recomputed from the policy content differs from the policy's `hash`
    /// field; `expected` is the value supplied with the policy, `actual` the recomputed one.
    #[error("hash mismatch: expected {expected}, got {actual}")]
    HashMismatch { expected: String, actual: String },
    /// An Ed25519 signature check failed.
    #[error("signature verification failed")]
    InvalidSignature,
    /// A signature was supplied but no public key was available to check it against.
    #[error("missing public key for signature verification")]
    MissingPublicKey,
    /// Signature or public-key material could not be decoded (e.g. wrong length or not a
    /// valid Ed25519 key).
    #[error("invalid signature bytes")]
    InvalidSignatureBytes,
    /// A policy rule denied an action; the payload is the rule identifier
    /// (e.g. `"network.blocked"`).
    #[error("policy violation: {0}")]
    Violation(String),
}

/// The active policy for a workspace, as returned by the policy API.
///
/// Field order matches the serialised JSON form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivePolicy {
    /// Policy identifier; part of the signed message `"{policy_id}:{version}:{hash}"`.
    pub policy_id: String,
    /// Identifier of the workspace this policy was fetched for.
    pub workspace_id: String,
    /// Policy version; part of the signed message.
    pub version: String,
    /// Status string reported by the server; not inspected by `PolicyClient` or
    /// `PolicyEnforcer`.
    pub status: String,
    /// Hex-encoded BLAKE3 hash of the JSON-serialised `content`; also part of the signed
    /// message.
    pub hash: String,
    /// The policy document. `PolicyEnforcer` reads sections such as `permissions`, `tools`,
    /// `agent.models` and `rate_limits` from it.
    pub content: serde_json::Value,
    /// Activation timestamp as sent by the server (format not validated here).
    pub activated_at: Option<String>,
    /// Last-update timestamp as sent by the server (format not validated here).
    pub updated_at: Option<String>,
    /// Optional hex-encoded Ed25519 signature over `"{policy_id}:{version}:{hash}"`.
    pub signature: Option<String>,
    /// Identifier of the signing key; not used during verification.
    pub signing_key_id: Option<String>,
    /// Hex-encoded Ed25519 public key accompanying `signature`.
    // NOTE(review): a key delivered alongside the signature only proves the policy is
    // self-consistent, not who produced it; pin a trusted key to authenticate the origin.
    pub signing_public_key: Option<String>,
}

pub struct PolicyClient {
    base_url: String,
    http: Client,
}

impl PolicyClient {
    pub fn new(base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            http: Client::new(),
        }
    }

    pub async fn fetch_active_policy(
        &self,
        workspace_id: &str,
    ) -> Result<ActivePolicy, PolicyError> {
        let url = format!(
            "{}/api/v1/workspaces/{}/policy/active",
            self.base_url, workspace_id
        );
        let res = self.http.get(&url).send().await?.error_for_status()?;
        let policy: ActivePolicy = res.json().await?;
        self.verify(&policy)?;
        Ok(policy)
    }

    pub fn verify(&self, policy: &ActivePolicy) -> Result<(), PolicyError> {
        let calculated = blake3::hash(serde_json::to_string(&policy.content)?.as_bytes())
            .to_hex()
            .to_string();
        if calculated != policy.hash {
            return Err(PolicyError::HashMismatch {
                expected: policy.hash.clone(),
                actual: calculated,
            });
        }

        if let Some(sig_hex) = &policy.signature {
            let pub_key_hex = policy
                .signing_public_key
                .as_ref()
                .ok_or(PolicyError::MissingPublicKey)?;

            let pub_key_bytes =
                hex::decode(pub_key_hex).map_err(|_| PolicyError::InvalidSignature)?;
            let verifying_key = VerifyingKey::from_bytes(
                pub_key_bytes
                    .as_slice()
                    .try_into()
                    .map_err(|_| PolicyError::InvalidSignatureBytes)?,
            )
            .map_err(|_| PolicyError::InvalidSignatureBytes)?;

            let message = format!("{}:{}:{}", policy.policy_id, policy.version, policy.hash);
            let sig_bytes = hex::decode(sig_hex).map_err(|_| PolicyError::InvalidSignatureBytes)?;
            let signature = Signature::from_bytes(
                sig_bytes
                    .as_slice()
                    .try_into()
                    .map_err(|_| PolicyError::InvalidSignatureBytes)?,
            );

            verifying_key
                .verify_strict(message.as_bytes(), &signature)
                .map_err(|_| PolicyError::InvalidSignature)?;
        }

        Ok(())
    }
}

/// Checks agent actions against the `content` of an [`ActivePolicy`].
///
/// Every `enforce_*` method returns `Ok(())` when the action is permitted and
/// `Err(PolicyError::Violation(rule))` for the first failing rule otherwise. When the
/// development override is enabled, violations are reported on stderr and allowed instead.
/// A policy section that is absent from `content` imposes no restrictions.
#[derive(Debug, Clone)]
pub struct PolicyEnforcer {
    policy: ActivePolicy,
    /// When `true`, violations are logged and allowed rather than returned as errors.
    allow_override: bool,
}

impl PolicyEnforcer {
    /// Creates an enforcer for `policy`.
    ///
    /// The development override is enabled only when the `SEKUIRE_POLICY_DEV_OVERRIDE`
    /// environment variable is set to exactly `"true"`.
    pub fn new(policy: ActivePolicy) -> Self {
        let allow_override = std::env::var("SEKUIRE_POLICY_DEV_OVERRIDE")
            .map(|v| v == "true")
            .unwrap_or(false);
        Self::new_with_override(policy, allow_override)
    }

    /// Creates an enforcer with an explicit override setting (the environment is not read).
    pub fn new_with_override(policy: ActivePolicy, allow_override: bool) -> Self {
        Self {
            policy,
            allow_override,
        }
    }

    /// Checks a connection to `domain` over `protocol` against `permissions.network`.
    ///
    /// Rules, in order: `network.disabled`, `network.tls_required`, `network.blocked`,
    /// `network.not_allowed`. Domains are matched case-insensitively and a trailing `.` is
    /// ignored. An empty `allowed_domains` list permits every domain.
    ///
    /// # Errors
    /// [`PolicyError::Violation`] for the first failing rule (unless overridden).
    pub fn enforce_network(&self, domain: &str, protocol: &str) -> Result<(), PolicyError> {
        let Some(perms) = self.permission_section("network") else {
            return Ok(());
        };
        if !Self::flag_enabled(perms, "enabled") {
            return self.violate("network.disabled");
        }
        // Scheme names are case-insensitive, so "HTTPS" also satisfies the TLS requirement.
        if Self::flag_enabled(perms, "require_tls") && !protocol.eq_ignore_ascii_case("https") {
            return self.violate("network.tls_required");
        }
        if let Some(blocked) = perms.get("blocked_domains").and_then(Value::as_array) {
            if self.matches(domain, blocked) {
                return self.violate("network.blocked");
            }
        }
        if let Some(allowed) = perms.get("allowed_domains").and_then(Value::as_array) {
            if !allowed.is_empty() && !self.matches(domain, allowed) {
                return self.violate("network.not_allowed");
            }
        }
        Ok(())
    }

    /// Checks access to `path` against `permissions.filesystem`.
    ///
    /// Rules, in order: `filesystem.disabled`, `filesystem.blocked`, `filesystem.not_allowed`.
    /// Paths containing `..` segments are resolved lexically before matching so traversal
    /// cannot escape an allowed directory or dodge a blocked one (symlinks are not resolved).
    ///
    /// # Errors
    /// [`PolicyError::Violation`] for the first failing rule (unless overridden).
    pub fn enforce_filesystem(&self, path: &str) -> Result<(), PolicyError> {
        let Some(perms) = self.permission_section("filesystem") else {
            return Ok(());
        };
        if !Self::flag_enabled(perms, "enabled") {
            return self.violate("filesystem.disabled");
        }
        if let Some(blocked) = perms.get("blocked_paths").and_then(Value::as_array) {
            if self.path_matches(path, blocked) {
                return self.violate("filesystem.blocked");
            }
        }
        if let Some(allowed) = perms.get("allowed_paths").and_then(Value::as_array) {
            if !self.path_matches(path, allowed) {
                return self.violate("filesystem.not_allowed");
            }
        }
        Ok(())
    }

    /// Checks use of `tool` against the `tools` section.
    ///
    /// `blocked_tools` holds tool-name strings; `allowed_tools` holds objects matched on their
    /// `name` field. Rules: `tool.blocked`, `tool.not_allowed`.
    ///
    /// # Errors
    /// [`PolicyError::Violation`] for the first failing rule (unless overridden).
    pub fn enforce_tool(&self, tool: &str) -> Result<(), PolicyError> {
        let Some(tools) = self.policy.content.get("tools") else {
            return Ok(());
        };
        let is_blocked = tools
            .get("blocked_tools")
            .and_then(Value::as_array)
            .is_some_and(|blocked| blocked.iter().any(|b| b == tool));
        if is_blocked {
            return self.violate("tool.blocked");
        }
        if let Some(allowed) = tools.get("allowed_tools").and_then(Value::as_array) {
            let is_listed = allowed
                .iter()
                .any(|t| t.get("name").is_some_and(|n| n == tool));
            if !is_listed {
                return self.violate("tool.not_allowed");
            }
        }
        Ok(())
    }

    /// Checks use of `model` against `agent.models` (`blocked_models` / `allowed_models`,
    /// both lists of strings). Rules: `model.blocked`, `model.not_allowed`.
    ///
    /// # Errors
    /// [`PolicyError::Violation`] for the first failing rule (unless overridden).
    pub fn enforce_model(&self, model: &str) -> Result<(), PolicyError> {
        let Some(models) = self
            .policy
            .content
            .get("agent")
            .and_then(|a| a.get("models"))
        else {
            return Ok(());
        };
        if let Some(blocked) = models.get("blocked_models").and_then(Value::as_array) {
            if blocked.iter().any(|m| m == model) {
                return self.violate("model.blocked");
            }
        }
        if let Some(allowed) = models.get("allowed_models").and_then(Value::as_array) {
            if !allowed.iter().any(|m| m == model) {
                return self.violate("model.not_allowed");
            }
        }
        Ok(())
    }

    /// Checks a call to `service` against `permissions.api`; `allowed_services` entries are
    /// objects matched on `service_name`. Rules: `api.disabled`, `api.not_allowed`.
    ///
    /// # Errors
    /// [`PolicyError::Violation`] for the first failing rule (unless overridden).
    pub fn enforce_api(&self, service: &str) -> Result<(), PolicyError> {
        let Some(perms) = self.permission_section("api") else {
            return Ok(());
        };
        if !Self::flag_enabled(perms, "enabled") {
            return self.violate("api.disabled");
        }
        if let Some(allowed) = perms.get("allowed_services").and_then(Value::as_array) {
            let is_listed = allowed
                .iter()
                .any(|s| s.get("service_name").is_some_and(|n| n == service));
            if !is_listed {
                return self.violate("api.not_allowed");
            }
        }
        Ok(())
    }

    /// Rate-limit enforcement hook.
    ///
    /// NOTE(review): not implemented — `rate_limits.per_agent` is not consulted and `_count`
    /// is ignored, so this always returns `Ok(())`. Real enforcement needs usage tracking.
    pub fn enforce_rate_limit(&self, _limit_type: &str, _count: u32) -> Result<(), PolicyError> {
        Ok(())
    }

    /// Returns `permissions.<section>` from the policy content, if present.
    fn permission_section(&self, section: &str) -> Option<&Value> {
        self.policy
            .content
            .get("permissions")
            .and_then(|p| p.get(section))
    }

    /// Reads boolean `key` from `section`; a missing or non-boolean value counts as `false`.
    fn flag_enabled(section: &Value, key: &str) -> bool {
        section.get(key).and_then(Value::as_bool).unwrap_or(false)
    }

    /// Reports a violation of `rule`: an error normally, a stderr log under the override.
    fn violate(&self, rule: &str) -> Result<(), PolicyError> {
        if self.allow_override {
            // stderr, not stdout: a host agent may use stdout as a protocol channel.
            eprintln!("[policy][override] {}", rule);
            Ok(())
        } else {
            Err(PolicyError::Violation(rule.to_string()))
        }
    }

    /// Returns whether `domain` matches any string in `patterns`.
    ///
    /// `"*"` matches everything; `"*.example.com"` matches `example.com` and its subdomains;
    /// anything else is an exact match. Comparison is ASCII case-insensitive and a single
    /// trailing `.` on `domain` is ignored, since `Example.COM.` names the same host.
    fn matches(&self, domain: &str, patterns: &[Value]) -> bool {
        let domain = domain.strip_suffix('.').unwrap_or(domain).to_ascii_lowercase();
        patterns.iter().filter_map(Value::as_str).any(|pattern| {
            let pattern = pattern.to_ascii_lowercase();
            if pattern == "*" {
                true
            } else if let Some(base) = pattern.strip_prefix("*.") {
                // Match both the base domain and subdomains (consistent with TS/Python)
                domain == base || domain.ends_with(&format!(".{base}"))
            } else {
                domain == pattern
            }
        })
    }

    /// Returns whether `path` matches any string in `patterns`.
    ///
    /// `"dir/*"` matches `dir` itself and anything below it on a `/` boundary (so `/tmp/*`
    /// does not match `/tmpevil`); a pattern ending in a bare `*` is a raw string prefix;
    /// anything else is an exact match. `path` is first resolved of `..` segments.
    fn path_matches(&self, path: &str, patterns: &[Value]) -> bool {
        let path = Self::resolve_parent_segments(path);
        patterns.iter().filter_map(Value::as_str).any(|pattern| {
            if let Some(dir) = pattern.strip_suffix("/*") {
                Self::is_within_dir(&path, dir)
            } else if let Some(prefix) = pattern.strip_suffix('*') {
                path.starts_with(prefix)
            } else {
                path == pattern
            }
        })
    }

    /// Returns whether `path` is `dir` or lies beneath it on a path-segment boundary.
    fn is_within_dir(path: &str, dir: &str) -> bool {
        if dir.is_empty() {
            // The pattern "/*" has always matched every path; keep that behaviour.
            return true;
        }
        path.strip_prefix(dir).is_some_and(|rest| {
            rest.is_empty() || rest.starts_with('/') || dir.ends_with('/')
        })
    }

    /// Lexically resolves `..` segments (also dropping `.` and empty segments) when `path`
    /// contains any `..`; other paths are returned unchanged. `..` above the root of an
    /// absolute path stays at the root; unresolvable leading `..` in a relative path is kept.
    fn resolve_parent_segments(path: &str) -> String {
        if !path.split('/').any(|segment| segment == "..") {
            return path.to_string();
        }
        let absolute = path.starts_with('/');
        let mut kept: Vec<&str> = Vec::new();
        for segment in path.split('/') {
            match segment {
                "" | "." => {}
                ".." => {
                    if kept.last().is_some_and(|last| *last != "..") {
                        kept.pop();
                    } else if !absolute {
                        kept.push("..");
                    }
                }
                other => kept.push(other),
            }
        }
        let joined = kept.join("/");
        if absolute {
            format!("/{joined}")
        } else {
            joined
        }
    }
}