use crate::error::{Error, Result};
use crate::types::DrmSystem;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use url::Url;
/// A PSSH (Protection System Specific Header) entry as handled by this module:
/// a DRM system ID plus that system's opaque payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PsshBox {
/// DRM system identifier as a hyphenated UUID string. `drm_system` and
/// `DrmManager::get_pssh` compare it case-insensitively.
pub system_id: String,
/// Key IDs associated with this box. `PsshBox::new` leaves this empty.
/// NOTE(review): string format (hex / UUID / base64) is not fixed here — confirm with producers.
pub key_ids: Vec<String>,
/// Payload bytes, stored base64-encoded (see `PsshBox::new` and `PsshBox::data_bytes`).
pub data: String,
}
impl PsshBox {
    /// Wraps a raw payload for the DRM system identified by `system_id`.
    ///
    /// The payload is stored base64-encoded in `data`; `key_ids` starts empty.
    pub fn new(system_id: &str, data: &[u8]) -> Self {
        let encoded = base64_encode(data);
        Self {
            system_id: String::from(system_id),
            key_ids: Vec::new(),
            data: encoded,
        }
    }

    /// Maps `system_id` to a known [`DrmSystem`], ignoring letter case.
    ///
    /// Recognises the Widevine, FairPlay, PlayReady and ClearKey system IDs
    /// in the table below. Any other ID yields `None`.
    pub fn drm_system(&self) -> Option<DrmSystem> {
        let normalized = self.system_id.to_lowercase();
        let known = [
            ("edef8ba9-79d6-4ace-a3c8-27dcd51d21ed", DrmSystem::Widevine),
            ("94ce86fb-07ff-4f43-adb8-93d2fa968ca2", DrmSystem::FairPlay),
            ("9a04f079-9840-4286-ab92-e65be0885f95", DrmSystem::PlayReady),
            ("1077efec-c0b2-4d02-ace3-3c1e52e2fb4b", DrmSystem::ClearKey),
        ];
        known
            .iter()
            .find(|entry| entry.0 == normalized)
            .map(|entry| entry.1)
    }

    /// Decodes `data` back into the raw payload bytes.
    ///
    /// # Errors
    /// Propagates the base64 decoder's error when `data` is not valid base64.
    pub fn data_bytes(&self) -> Result<Vec<u8>> {
        let encoded: &str = &self.data;
        base64_decode(encoded)
    }
}
/// License-server and key configuration for the supported DRM systems.
///
/// A system counts as configured when its license URL is set, or for
/// ClearKey when at least one key is present (see `DrmConfig::supported_systems`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DrmConfig {
/// Widevine license endpoint; required by `DrmManager::create_widevine_request`.
pub widevine_license_url: Option<Url>,
/// PlayReady license endpoint; enables PlayReady in `supported_systems`.
pub playready_license_url: Option<Url>,
/// FairPlay certificate location. Not consulted by `is_configured` or `supported_systems`.
pub fairplay_certificate_url: Option<Url>,
/// FairPlay license endpoint; required by `DrmManager::create_fairplay_request`.
pub fairplay_license_url: Option<Url>,
/// Headers copied verbatim into every `LicenseRequest`
/// (presumably sent as HTTP headers to the license server — confirm with the transport).
pub license_headers: HashMap<String, String>,
/// FairPlay content identifier. Not read anywhere in this file.
pub fairplay_content_id: Option<String>,
/// ClearKey key-ID → key map, emitted as the `kid`/`k` pairs of the ClearKey license.
pub clearkey_keys: HashMap<String, String>,
/// Persistence flag (false by default). Not read anywhere in this file.
pub persist_license: bool,
/// License duration (0 by default). Not read anywhere in this file; units unknown — TODO confirm.
pub license_duration: u64,
}
impl Default for DrmConfig {
    /// An empty configuration: no license or certificate URLs, no headers,
    /// no FairPlay content ID, no ClearKey keys, persistence off and a zero
    /// `license_duration`. `is_configured()` is therefore `false`.
    fn default() -> Self {
        Self {
            widevine_license_url: Option::default(),
            playready_license_url: Option::default(),
            fairplay_certificate_url: Option::default(),
            fairplay_license_url: Option::default(),
            license_headers: HashMap::default(),
            fairplay_content_id: Option::default(),
            clearkey_keys: HashMap::default(),
            persist_license: bool::default(),
            license_duration: u64::default(),
        }
    }
}
impl DrmConfig {
    /// Creates a configuration whose only DRM system is Widevine, licensed from `license_url`.
    pub fn widevine(license_url: Url) -> Self {
        Self {
            widevine_license_url: Some(license_url),
            ..Default::default()
        }
    }

    /// Creates a configuration whose only DRM system is PlayReady, licensed from `license_url`.
    ///
    /// Mirrors `widevine` / `fairplay`. Before this constructor existed, PlayReady
    /// could only be enabled by setting `playready_license_url` by hand.
    pub fn playready(license_url: Url) -> Self {
        Self {
            playready_license_url: Some(license_url),
            ..Default::default()
        }
    }

    /// Creates a FairPlay configuration from its license and certificate URLs.
    pub fn fairplay(license_url: Url, certificate_url: Url) -> Self {
        Self {
            fairplay_license_url: Some(license_url),
            fairplay_certificate_url: Some(certificate_url),
            ..Default::default()
        }
    }

    /// Creates a ClearKey configuration from a key-ID → key map.
    pub fn clearkey(keys: HashMap<String, String>) -> Self {
        Self {
            clearkey_keys: keys,
            ..Default::default()
        }
    }

    /// Adds (or replaces) a header copied into every license request; builder-style.
    pub fn with_header(mut self, key: &str, value: &str) -> Self {
        self.license_headers.insert(key.to_string(), value.to_string());
        self
    }

    /// Returns `true` when at least one DRM system is usable.
    ///
    /// That means a Widevine, PlayReady or FairPlay license URL is set, or
    /// ClearKey keys exist. A FairPlay certificate URL alone does not count.
    pub fn is_configured(&self) -> bool {
        self.widevine_license_url.is_some()
            || self.playready_license_url.is_some()
            || self.fairplay_license_url.is_some()
            || !self.clearkey_keys.is_empty()
    }

    /// Lists the configured DRM systems in the fixed order Widevine, PlayReady,
    /// FairPlay, ClearKey. Each appears only if `is_configured` would count it.
    pub fn supported_systems(&self) -> Vec<DrmSystem> {
        let mut systems = Vec::new();
        if self.widevine_license_url.is_some() {
            systems.push(DrmSystem::Widevine);
        }
        if self.playready_license_url.is_some() {
            systems.push(DrmSystem::PlayReady);
        }
        if self.fairplay_license_url.is_some() {
            systems.push(DrmSystem::FairPlay);
        }
        if !self.clearkey_keys.is_empty() {
            systems.push(DrmSystem::ClearKey);
        }
        systems
    }
}
/// A license request assembled by `DrmManager` for one DRM system.
#[derive(Debug, Clone)]
pub struct LicenseRequest {
/// DRM system the request targets.
pub system: DrmSystem,
/// Opaque challenge payload (for FairPlay, the `spc` passed to `create_fairplay_request`).
pub challenge: Vec<u8>,
/// License server endpoint taken from `DrmConfig`.
pub license_url: Url,
/// Copy of `DrmConfig::license_headers` at the time the request was built.
pub headers: HashMap<String, String>,
}
/// A license obtained for one DRM system.
#[derive(Debug, Clone)]
pub struct LicenseResponse {
/// DRM system the license belongs to.
pub system: DrmSystem,
/// Raw license bytes (for ClearKey, the UTF-8 JSON built by `get_clearkey_license`).
pub license: Vec<u8>,
/// Expiration copied into `DrmSession::expiration` by `process_license`.
/// `DrmSession::is_expired` compares it with Unix time in seconds; 0 means no expiry.
pub expiration: u64,
}
/// Lifecycle state of a `DrmSession`.
///
/// In this file, sessions start `Idle` (`DrmSession::new`) and become `Ready`
/// in `DrmManager::process_license`. The remaining variants are presumably set
/// by callers elsewhere — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DrmSessionState {
/// Freshly created; nothing requested yet.
Idle,
/// Waiting for a server certificate (name suggests FairPlay-style flows — confirm).
AwaitingCertificate,
/// A license challenge is being produced.
GeneratingChallenge,
/// Challenge sent; waiting for the license.
AwaitingLicense,
/// License applied; `DrmSession::is_ready` returns true only in this state.
Ready,
/// License lifetime has passed.
Expired,
/// Session failed; details presumably in `DrmSession::error`.
Error,
}
/// Book-keeping for one DRM license session managed by `DrmManager`.
#[derive(Debug, Clone)]
pub struct DrmSession {
/// Unique session ID (a v4 UUID string from `DrmSession::new`); key in `DrmManager`'s session map.
pub id: String,
/// DRM system this session was opened for.
pub system: DrmSystem,
/// Current lifecycle state.
pub state: DrmSessionState,
/// Key IDs covered by the session (empty on creation; not filled in this file).
pub key_ids: Vec<String>,
/// Expiration in Unix seconds as used by `is_expired`; 0 means it never expires.
pub expiration: u64,
/// Last error message, if any.
pub error: Option<String>,
}
impl DrmSession {
    /// Opens a fresh session for `system`.
    ///
    /// The session starts `Idle`, with a newly generated v4 UUID string as its
    /// ID, no key IDs, no expiration (`0`) and no error.
    pub fn new(system: DrmSystem) -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        Self {
            id,
            system,
            state: DrmSessionState::Idle,
            key_ids: Vec::new(),
            expiration: 0,
            error: None,
        }
    }

    /// `true` exactly when the session is in `DrmSessionState::Ready`.
    pub fn is_ready(&self) -> bool {
        matches!(self.state, DrmSessionState::Ready)
    }

    /// Reports whether the expiration time has been reached.
    ///
    /// `expiration` is compared with the current time in whole seconds since
    /// the Unix epoch, and `0` means "never expires". If the system clock reads
    /// earlier than the epoch, the current time is treated as `0`.
    pub fn is_expired(&self) -> bool {
        let deadline = self.expiration;
        deadline != 0 && {
            let now_secs = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_or(0, |elapsed| elapsed.as_secs());
            now_secs >= deadline
        }
    }
}
/// Owns the DRM configuration, the PSSH boxes of the current content and all
/// open license sessions.
pub struct DrmManager {
    /// PSSH boxes for the current content; replaced via `set_pssh_boxes`.
    pssh_boxes: Vec<PsshBox>,
    /// Open sessions keyed by `DrmSession::id`.
    sessions: HashMap<String, DrmSession>,
    /// License-server and key configuration supplied at construction.
    config: DrmConfig,
}
impl DrmManager {
pub fn new(config: DrmConfig) -> Self {
Self {
config,
sessions: HashMap::new(),
pssh_boxes: Vec::new(),
}
}
pub fn set_pssh_boxes(&mut self, boxes: Vec<PsshBox>) {
self.pssh_boxes = boxes;
}
pub fn get_pssh(&self, system: DrmSystem) -> Option<&PsshBox> {
let target_id = system.system_id().to_lowercase();
self.pssh_boxes.iter().find(|p| p.system_id.to_lowercase() == target_id)
}
pub fn create_widevine_request(&self, challenge: Vec<u8>) -> Result<LicenseRequest> {
let license_url = self.config.widevine_license_url.clone()
.ok_or_else(|| Error::drm("Widevine license URL not configured"))?;
Ok(LicenseRequest {
system: DrmSystem::Widevine,
challenge,
license_url,
headers: self.config.license_headers.clone(),
})
}
pub fn create_fairplay_request(&self, spc: Vec<u8>) -> Result<LicenseRequest> {
let license_url = self.config.fairplay_license_url.clone()
.ok_or_else(|| Error::drm("FairPlay license URL not configured"))?;
Ok(LicenseRequest {
system: DrmSystem::FairPlay,
challenge: spc,
license_url,
headers: self.config.license_headers.clone(),
})
}
pub fn get_clearkey_license(&self) -> Result<LicenseResponse> {
if self.config.clearkey_keys.is_empty() {
return Err(Error::drm("No ClearKey keys configured"));
}
let keys: Vec<serde_json::Value> = self.config.clearkey_keys.iter()
.map(|(kid, key)| {
serde_json::json!({
"kty": "oct",
"kid": kid,
"k": key,
})
})
.collect();
let license_json = serde_json::json!({
"keys": keys,
"type": "temporary",
});
Ok(LicenseResponse {
system: DrmSystem::ClearKey,
license: license_json.to_string().into_bytes(),
expiration: 0,
})
}
pub fn create_session(&mut self, system: DrmSystem) -> &DrmSession {
let session = DrmSession::new(system);
let id = session.id.clone();
self.sessions.insert(id.clone(), session);
self.sessions.get(&id).unwrap()
}
pub fn process_license(&mut self, session_id: &str, response: LicenseResponse) -> Result<()> {
let session = self.sessions.get_mut(session_id)
.ok_or_else(|| Error::drm("Session not found"))?;
session.state = DrmSessionState::Ready;
session.expiration = response.expiration;
Ok(())
}
pub fn sessions(&self) -> impl Iterator<Item = &DrmSession> {
self.sessions.values()
}
pub fn get_session(&self, id: &str) -> Option<&DrmSession> {
self.sessions.get(id)
}
pub fn close_session(&mut self, id: &str) {
self.sessions.remove(id);
}
pub fn close_all_sessions(&mut self) {
self.sessions.clear();
}
pub fn is_drm_required(&self) -> bool {
!self.pssh_boxes.is_empty()
}
pub fn select_drm_system(&self) -> Option<DrmSystem> {
let supported = self.config.supported_systems();
for system in &[DrmSystem::Widevine, DrmSystem::FairPlay, DrmSystem::PlayReady, DrmSystem::ClearKey] {
if supported.contains(system) && self.get_pssh(*system).is_some() {
return Some(*system);
}
}
if !self.config.clearkey_keys.is_empty() {
return Some(DrmSystem::ClearKey);
}
None
}
}
/// Encodes `data` as standard base64 (RFC 4648 §4: `+`/`/` alphabet, `=` padding).
///
/// Every 3 input bytes become 4 output characters. A final group of 1 or 2
/// bytes is padded with `==` or `=` respectively. Empty input yields an empty string.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Output length is known up front: 4 chars per (possibly partial) 3-byte group.
    let mut result = String::with_capacity((data.len() + 2) / 3 * 4);
    for chunk in data.chunks(3) {
        // `chunks` never yields an empty slice, so `chunk[0]` always exists;
        // missing trailing bytes are treated as zero and later replaced by `=`.
        let b0 = chunk[0];
        let b1 = chunk.get(1).copied().unwrap_or(0);
        let b2 = chunk.get(2).copied().unwrap_or(0);
        let n = (u32::from(b0) << 16) | (u32::from(b1) << 8) | u32::from(b2);
        let sextet = |shift: u32| ALPHABET[((n >> shift) & 0x3F) as usize] as char;
        result.push(sextet(18));
        result.push(sextet(12));
        result.push(if chunk.len() > 1 { sextet(6) } else { '=' });
        result.push(if chunk.len() > 2 { sextet(0) } else { '=' });
    }
    result
}
fn base64_decode(data: &str) -> Result<Vec<u8>> {
const DECODE_TABLE: &[i8; 128] = &[
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1,
-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1,
-1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1,
];
let input: Vec<u8> = data.bytes()
.filter(|b| *b != b'=' && *b != b'\n' && *b != b'\r')
.collect();
let mut result = Vec::with_capacity(input.len() * 3 / 4);
for chunk in input.chunks(4) {
let mut n: u32 = 0;
let chunk_len = chunk.len();
for (i, &b) in chunk.iter().enumerate() {
if b as usize >= 128 {
return Err(Error::drm("Invalid base64 character"));
}
let val = DECODE_TABLE[b as usize];
if val < 0 {
return Err(Error::drm("Invalid base64 character"));
}
n |= (val as u32) << (18 - i * 6);
}
result.push((n >> 16) as u8);
if chunk_len > 2 {
result.push((n >> 8) as u8);
}
if chunk_len > 3 {
result.push(n as u8);
}
}
Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_drm_config() {
        let config = DrmConfig::default();
        assert!(!config.is_configured());
        let config = DrmConfig::widevine(Url::parse("https://license.example.com").unwrap());
        assert!(config.is_configured());
        assert!(config.supported_systems().contains(&DrmSystem::Widevine));
    }

    #[test]
    fn test_pssh_box() {
        let pssh = PsshBox::new(DrmSystem::Widevine.system_id(), b"test data");
        assert_eq!(pssh.drm_system(), Some(DrmSystem::Widevine));
    }

    #[test]
    fn test_base64_roundtrip() {
        let original = b"Hello, DRM!";
        let encoded = base64_encode(original);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(original.to_vec(), decoded);
    }

    /// A roundtrip alone would pass with matching encode/decode bugs; pin the
    /// RFC 4648 §10 test vectors as well.
    #[test]
    fn test_base64_known_vectors() {
        let vectors: [(&[u8], &str); 7] = [
            (b"", ""),
            (b"f", "Zg=="),
            (b"fo", "Zm8="),
            (b"foo", "Zm9v"),
            (b"foob", "Zm9vYg=="),
            (b"fooba", "Zm9vYmE="),
            (b"foobar", "Zm9vYmFy"),
        ];
        for (raw, encoded) in vectors.iter() {
            assert_eq!(base64_encode(raw), *encoded);
            assert_eq!(base64_decode(encoded).unwrap(), raw.to_vec());
        }
    }

    #[test]
    fn test_clearkey_license() {
        let mut keys = HashMap::new();
        keys.insert("abc123".to_string(), "key456".to_string());
        let config = DrmConfig::clearkey(keys);
        let manager = DrmManager::new(config);
        let license = manager.get_clearkey_license().unwrap();
        assert_eq!(license.system, DrmSystem::ClearKey);
    }

    #[test]
    fn test_select_drm_system_requires_pssh() {
        let config = DrmConfig::widevine(Url::parse("https://license.example.com").unwrap());
        let mut manager = DrmManager::new(config);
        assert!(!manager.is_drm_required());
        assert_eq!(manager.select_drm_system(), None);
        manager.set_pssh_boxes(vec![PsshBox::new(DrmSystem::Widevine.system_id(), b"pssh")]);
        assert!(manager.is_drm_required());
        assert_eq!(manager.select_drm_system(), Some(DrmSystem::Widevine));
    }

    #[test]
    fn test_clearkey_session_becomes_ready() {
        let mut keys = HashMap::new();
        keys.insert("abc123".to_string(), "key456".to_string());
        let mut manager = DrmManager::new(DrmConfig::clearkey(keys));
        // ClearKey is selectable without any PSSH box.
        assert_eq!(manager.select_drm_system(), Some(DrmSystem::ClearKey));
        let id = manager.create_session(DrmSystem::ClearKey).id.clone();
        assert!(!manager.get_session(&id).unwrap().is_ready());
        let license = manager.get_clearkey_license().unwrap();
        manager.process_license(&id, license).unwrap();
        assert!(manager.get_session(&id).unwrap().is_ready());
        let license = manager.get_clearkey_license().unwrap();
        assert!(manager.process_license("missing-session", license).is_err());
        manager.close_session(&id);
        assert!(manager.get_session(&id).is_none());
    }
}