Skip to main content

gosuto_libwebrtc/native/
frame_cryptor.rs

1// Copyright 2025 LiveKit, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use cxx::SharedPtr;
18use parking_lot::Mutex;
19use gosuto_webrtc_sys::frame_cryptor::{self as sys_fc};
20
21use crate::{
22    peer_connection_factory::PeerConnectionFactory, rtp_receiver::RtpReceiver,
23    rtp_sender::RtpSender,
24};
25
/// Callback type accepted by `FrameCryptor::on_state_change`.
///
/// The callback receives the participant id and the [`EncryptionState`] reported by the
/// native observer. The boxed closure is kept behind a mutex inside the observer that is
/// registered with the native frame cryptor.
pub type OnStateChange = Box<dyn FnMut(String, EncryptionState) + Send + Sync>;
27
/// Key-derivation algorithm selected through [`KeyProviderOptions`].
///
/// Converts to and from `sys_fc::ffi::KeyDerivationAlgorithm` variant-for-variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeyDerivationAlgorithm {
    /// Maps to the native `Pbkdf2` value.
    Pbkdf2,
    /// Maps to the native `Hkdf` value.
    Hkdf,
}
33
/// Configuration consumed by [`KeyProvider::new`].
///
/// Every field is copied as-is into the native `sys_fc::ffi::KeyProviderOptions`
/// (only `key_derivation_algorithm` goes through a type conversion); the meaning of each
/// value is defined by the native implementation.
#[derive(Debug, Clone)]
pub struct KeyProviderOptions {
    // NOTE(review): presumably selects one key shared by all participants instead of
    // per-participant keys — confirm against the native key provider.
    pub shared_key: bool,
    // NOTE(review): presumably bounds how many ratchet steps are attempted — confirm.
    pub ratchet_window_size: i32,
    /// Salt bytes handed to the native provider.
    pub ratchet_salt: Vec<u8>,
    // NOTE(review): presumably the number of tolerated failures before an error state is
    // reported; valid range (e.g. whether negative values are special) is not visible here.
    pub failure_tolerance: i32,
    /// Key-derivation algorithm forwarded to the native provider.
    pub key_derivation_algorithm: KeyDerivationAlgorithm,
}
42
/// Cipher selection passed to the [`FrameCryptor`] and [`DataPacketCryptor`] constructors.
///
/// Converts to and from `sys_fc::ffi::Algorithm` variant-for-variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncryptionAlgorithm {
    /// Maps to the native `AesGcm` value.
    AesGcm,
    /// Maps to the native `AesCbc` value.
    AesCbc,
}
48
/// State value delivered to an [`OnStateChange`] callback.
///
/// Each variant mirrors the identically named `sys_fc::ffi::FrameCryptionState` value; the
/// conditions under which each one is emitted are decided by the native frame cryptor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncryptionState {
    New,
    Ok,
    EncryptionFailed,
    DecryptionFailed,
    MissingKey,
    KeyRatcheted,
    InternalError,
}
59
/// Packet returned by [`DataPacketCryptor::encrypt`] and accepted by
/// [`DataPacketCryptor::decrypt`].
///
/// Converts to and from `sys_fc::ffi::EncryptedPacket` field-for-field.
#[derive(Debug, Clone)]
pub struct EncryptedPacket {
    // NOTE(review): presumably the ciphertext bytes — confirm against the native cryptor.
    pub data: Vec<u8>,
    // NOTE(review): presumably the initialization vector used for `data` — confirm.
    pub iv: Vec<u8>,
    /// Key index carried alongside the packet.
    pub key_index: u32,
}
66
67#[derive(Clone)]
68pub struct KeyProvider {
69    pub(crate) sys_handle: SharedPtr<sys_fc::ffi::KeyProvider>,
70}
71
72impl KeyProvider {
73    pub fn new(options: KeyProviderOptions) -> Self {
74        Self { sys_handle: sys_fc::ffi::new_key_provider(options.into()) }
75    }
76
77    pub fn set_shared_key(&self, key_index: i32, key: Vec<u8>) -> bool {
78        self.sys_handle.set_shared_key(key_index, key)
79    }
80
81    pub fn ratchet_shared_key(&self, key_index: i32) -> Option<Vec<u8>> {
82        self.sys_handle.ratchet_shared_key(key_index).ok()
83    }
84
85    pub fn get_shared_key(&self, key_index: i32) -> Option<Vec<u8>> {
86        self.sys_handle.get_shared_key(key_index).ok()
87    }
88
89    pub fn set_key(&self, participant_id: String, key_index: i32, key: Vec<u8>) -> bool {
90        self.sys_handle.set_key(participant_id, key_index, key)
91    }
92
93    pub fn ratchet_key(&self, participant_id: String, key_index: i32) -> Option<Vec<u8>> {
94        self.sys_handle.ratchet_key(participant_id, key_index).ok()
95    }
96
97    pub fn get_key(&self, participant_id: String, key_index: i32) -> Option<Vec<u8>> {
98        self.sys_handle.get_key(participant_id, key_index).ok()
99    }
100
101    pub fn set_sif_trailer(&self, trailer: Vec<u8>) {
102        self.sys_handle.set_sif_trailer(trailer);
103    }
104}
105
106#[derive(Clone)]
107pub struct FrameCryptor {
108    observer: Arc<RtcFrameCryptorObserver>,
109    pub(crate) sys_handle: SharedPtr<sys_fc::ffi::FrameCryptor>,
110}
111
112impl FrameCryptor {
113    pub fn new_for_rtp_sender(
114        peer_factory: &PeerConnectionFactory,
115        participant_id: String,
116        algorithm: EncryptionAlgorithm,
117        key_provider: KeyProvider,
118        sender: RtpSender,
119    ) -> Self {
120        let observer = Arc::new(RtcFrameCryptorObserver::default());
121        let sys_handle = sys_fc::ffi::new_frame_cryptor_for_rtp_sender(
122            peer_factory.handle.sys_handle.clone(),
123            participant_id,
124            algorithm.into(),
125            key_provider.sys_handle,
126            sender.handle.sys_handle,
127        );
128        let fc = Self { observer: observer.clone(), sys_handle: sys_handle.clone() };
129        fc.sys_handle
130            .register_observer(Box::new(sys_fc::RtcFrameCryptorObserverWrapper::new(observer)));
131        fc
132    }
133
134    pub fn new_for_rtp_receiver(
135        peer_factory: &PeerConnectionFactory,
136        participant_id: String,
137        algorithm: EncryptionAlgorithm,
138        key_provider: KeyProvider,
139        receiver: RtpReceiver,
140    ) -> Self {
141        let observer = Arc::new(RtcFrameCryptorObserver::default());
142        let sys_handle = sys_fc::ffi::new_frame_cryptor_for_rtp_receiver(
143            peer_factory.handle.sys_handle.clone(),
144            participant_id,
145            algorithm.into(),
146            key_provider.sys_handle,
147            receiver.handle.sys_handle,
148        );
149        let fc = Self { observer: observer.clone(), sys_handle: sys_handle.clone() };
150        fc.sys_handle
151            .register_observer(Box::new(sys_fc::RtcFrameCryptorObserverWrapper::new(observer)));
152        fc
153    }
154
155    pub fn set_enabled(self: &FrameCryptor, enabled: bool) {
156        self.sys_handle.set_enabled(enabled);
157    }
158
159    pub fn enabled(self: &FrameCryptor) -> bool {
160        self.sys_handle.enabled()
161    }
162
163    pub fn set_key_index(self: &FrameCryptor, index: i32) {
164        self.sys_handle.set_key_index(index);
165    }
166
167    pub fn key_index(self: &FrameCryptor) -> i32 {
168        self.sys_handle.key_index()
169    }
170
171    pub fn participant_id(self: &FrameCryptor) -> String {
172        self.sys_handle.participant_id()
173    }
174
175    pub fn on_state_change(&self, handler: Option<OnStateChange>) {
176        *self.observer.state_change_handler.lock() = handler;
177    }
178}
179
180#[derive(Clone)]
181pub struct DataPacketCryptor {
182    pub(crate) sys_handle: SharedPtr<sys_fc::ffi::DataPacketCryptor>,
183}
184
185impl DataPacketCryptor {
186    pub fn new(algorithm: EncryptionAlgorithm, key_provider: KeyProvider) -> Self {
187        Self {
188            sys_handle: sys_fc::ffi::new_data_packet_cryptor(
189                algorithm.into(),
190                key_provider.sys_handle,
191            ),
192        }
193    }
194
195    pub fn encrypt(
196        &self,
197        participant_id: &str,
198        key_index: u32,
199        data: &[u8],
200    ) -> Result<EncryptedPacket, Box<dyn std::error::Error>> {
201        let data_vec: Vec<u8> = data.to_vec();
202        match self.sys_handle.encrypt_data_packet(participant_id.to_string(), key_index, data_vec) {
203            Ok(packet) => Ok(packet.into()),
204            Err(e) => Err(format!("Encryption failed: {}", e).into()),
205        }
206    }
207
208    pub fn decrypt(
209        &self,
210        participant_id: &str,
211        encrypted_packet: &EncryptedPacket,
212    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
213        match self
214            .sys_handle
215            .decrypt_data_packet(participant_id.to_string(), &encrypted_packet.clone().into())
216        {
217            Ok(data) => Ok(data.into_iter().collect()),
218            Err(e) => Err(format!("Decryption failed: {}", e).into()),
219        }
220    }
221}
222
223#[derive(Default)]
224struct RtcFrameCryptorObserver {
225    state_change_handler: Mutex<Option<OnStateChange>>,
226}
227
228impl sys_fc::RtcFrameCryptorObserver for RtcFrameCryptorObserver {
229    fn on_frame_cryption_state_change(
230        &self,
231        participant_id: String,
232        state: sys_fc::ffi::FrameCryptionState,
233    ) {
234        let mut handler = self.state_change_handler.lock();
235        if let Some(f) = handler.as_mut() {
236            f(participant_id, state.into());
237        }
238    }
239}
240
241impl From<sys_fc::ffi::Algorithm> for EncryptionAlgorithm {
242    fn from(value: sys_fc::ffi::Algorithm) -> Self {
243        match value {
244            sys_fc::ffi::Algorithm::AesGcm => Self::AesGcm,
245            sys_fc::ffi::Algorithm::AesCbc => Self::AesCbc,
246            _ => panic!("unknown frame cyrptor Algorithm"),
247        }
248    }
249}
250
251impl From<EncryptionAlgorithm> for sys_fc::ffi::Algorithm {
252    fn from(value: EncryptionAlgorithm) -> Self {
253        match value {
254            EncryptionAlgorithm::AesGcm => Self::AesGcm,
255            EncryptionAlgorithm::AesCbc => Self::AesCbc,
256        }
257    }
258}
259
260impl From<sys_fc::ffi::FrameCryptionState> for EncryptionState {
261    fn from(value: sys_fc::ffi::FrameCryptionState) -> Self {
262        match value {
263            sys_fc::ffi::FrameCryptionState::New => Self::New,
264            sys_fc::ffi::FrameCryptionState::Ok => Self::Ok,
265            sys_fc::ffi::FrameCryptionState::EncryptionFailed => Self::EncryptionFailed,
266            sys_fc::ffi::FrameCryptionState::DecryptionFailed => Self::DecryptionFailed,
267            sys_fc::ffi::FrameCryptionState::MissingKey => Self::MissingKey,
268            sys_fc::ffi::FrameCryptionState::KeyRatcheted => Self::KeyRatcheted,
269            sys_fc::ffi::FrameCryptionState::InternalError => Self::InternalError,
270            _ => panic!("unknown frame cyrptor FrameCryptionState"),
271        }
272    }
273}
274
275impl From<KeyDerivationAlgorithm> for sys_fc::ffi::KeyDerivationAlgorithm {
276    fn from(value: KeyDerivationAlgorithm) -> Self {
277        match value {
278            KeyDerivationAlgorithm::Pbkdf2 => Self::Pbkdf2,
279            KeyDerivationAlgorithm::Hkdf => Self::Hkdf,
280        }
281    }
282}
283
284impl From<sys_fc::ffi::KeyDerivationAlgorithm> for KeyDerivationAlgorithm {
285    fn from(value: sys_fc::ffi::KeyDerivationAlgorithm) -> Self {
286        match value {
287            sys_fc::ffi::KeyDerivationAlgorithm::Pbkdf2 => Self::Pbkdf2,
288            sys_fc::ffi::KeyDerivationAlgorithm::Hkdf => Self::Hkdf,
289            _ => panic!("unknown KeyDerivationAlgorithm"),
290        }
291    }
292}
293
294impl From<KeyProviderOptions> for sys_fc::ffi::KeyProviderOptions {
295    fn from(value: KeyProviderOptions) -> Self {
296        Self {
297            shared_key: value.shared_key,
298            ratchet_window_size: value.ratchet_window_size,
299            ratchet_salt: value.ratchet_salt,
300            failure_tolerance: value.failure_tolerance,
301            key_derivation_algorithm: value.key_derivation_algorithm.into(),
302        }
303    }
304}
305
306impl From<sys_fc::ffi::EncryptedPacket> for EncryptedPacket {
307    fn from(value: sys_fc::ffi::EncryptedPacket) -> Self {
308        Self {
309            data: value.data.into_iter().collect(),
310            iv: value.iv.into_iter().collect(),
311            key_index: value.key_index,
312        }
313    }
314}
315
316impl From<EncryptedPacket> for sys_fc::ffi::EncryptedPacket {
317    fn from(value: EncryptedPacket) -> Self {
318        Self { data: value.data, iv: value.iv, key_index: value.key_index }
319    }
320}