1use std::sync::Arc;
16
17use cxx::SharedPtr;
18use parking_lot::Mutex;
19use gosuto_webrtc_sys::frame_cryptor::{self as sys_fc};
20
21use crate::{
22 peer_connection_factory::PeerConnectionFactory, rtp_receiver::RtpReceiver,
23 rtp_sender::RtpSender,
24};
25
26pub type OnStateChange = Box<dyn FnMut(String, EncryptionState) + Send + Sync>;
27
/// Key derivation function selected through `KeyProviderOptions::key_derivation_algorithm`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KeyDerivationAlgorithm {
    /// PBKDF2-based derivation.
    Pbkdf2,
    /// HKDF-based derivation.
    Hkdf,
}
33
34#[derive(Debug, Clone)]
35pub struct KeyProviderOptions {
36 pub shared_key: bool,
37 pub ratchet_window_size: i32,
38 pub ratchet_salt: Vec<u8>,
39 pub failure_tolerance: i32,
40 pub key_derivation_algorithm: KeyDerivationAlgorithm,
41}
42
/// Symmetric cipher used by [`FrameCryptor`] and [`DataPacketCryptor`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EncryptionAlgorithm {
    /// AES in GCM mode.
    AesGcm,
    /// AES in CBC mode.
    AesCbc,
}
48
/// State reported by the native frame cryptor and delivered to [`OnStateChange`] handlers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EncryptionState {
    /// Initial state before any frame has been processed.
    New,
    /// Frames are being processed successfully.
    Ok,
    /// Encrypting a frame failed.
    EncryptionFailed,
    /// Decrypting a frame failed.
    DecryptionFailed,
    /// No key was available for the requested key index.
    MissingKey,
    /// The key was ratcheted.
    KeyRatcheted,
    /// An internal error occurred in the native cryptor.
    InternalError,
}
59
/// Output of [`DataPacketCryptor::encrypt`] and input to [`DataPacketCryptor::decrypt`].
#[derive(Clone, Debug)]
pub struct EncryptedPacket {
    /// Encrypted payload bytes.
    pub data: Vec<u8>,
    /// Initialization vector that accompanies the payload; handed back on decrypt.
    pub iv: Vec<u8>,
    /// Index of the key the packet was encrypted with.
    pub key_index: u32,
}
66
67#[derive(Clone)]
68pub struct KeyProvider {
69 pub(crate) sys_handle: SharedPtr<sys_fc::ffi::KeyProvider>,
70}
71
72impl KeyProvider {
73 pub fn new(options: KeyProviderOptions) -> Self {
74 Self { sys_handle: sys_fc::ffi::new_key_provider(options.into()) }
75 }
76
77 pub fn set_shared_key(&self, key_index: i32, key: Vec<u8>) -> bool {
78 self.sys_handle.set_shared_key(key_index, key)
79 }
80
81 pub fn ratchet_shared_key(&self, key_index: i32) -> Option<Vec<u8>> {
82 self.sys_handle.ratchet_shared_key(key_index).ok()
83 }
84
85 pub fn get_shared_key(&self, key_index: i32) -> Option<Vec<u8>> {
86 self.sys_handle.get_shared_key(key_index).ok()
87 }
88
89 pub fn set_key(&self, participant_id: String, key_index: i32, key: Vec<u8>) -> bool {
90 self.sys_handle.set_key(participant_id, key_index, key)
91 }
92
93 pub fn ratchet_key(&self, participant_id: String, key_index: i32) -> Option<Vec<u8>> {
94 self.sys_handle.ratchet_key(participant_id, key_index).ok()
95 }
96
97 pub fn get_key(&self, participant_id: String, key_index: i32) -> Option<Vec<u8>> {
98 self.sys_handle.get_key(participant_id, key_index).ok()
99 }
100
101 pub fn set_sif_trailer(&self, trailer: Vec<u8>) {
102 self.sys_handle.set_sif_trailer(trailer);
103 }
104}
105
106#[derive(Clone)]
107pub struct FrameCryptor {
108 observer: Arc<RtcFrameCryptorObserver>,
109 pub(crate) sys_handle: SharedPtr<sys_fc::ffi::FrameCryptor>,
110}
111
112impl FrameCryptor {
113 pub fn new_for_rtp_sender(
114 peer_factory: &PeerConnectionFactory,
115 participant_id: String,
116 algorithm: EncryptionAlgorithm,
117 key_provider: KeyProvider,
118 sender: RtpSender,
119 ) -> Self {
120 let observer = Arc::new(RtcFrameCryptorObserver::default());
121 let sys_handle = sys_fc::ffi::new_frame_cryptor_for_rtp_sender(
122 peer_factory.handle.sys_handle.clone(),
123 participant_id,
124 algorithm.into(),
125 key_provider.sys_handle,
126 sender.handle.sys_handle,
127 );
128 let fc = Self { observer: observer.clone(), sys_handle: sys_handle.clone() };
129 fc.sys_handle
130 .register_observer(Box::new(sys_fc::RtcFrameCryptorObserverWrapper::new(observer)));
131 fc
132 }
133
134 pub fn new_for_rtp_receiver(
135 peer_factory: &PeerConnectionFactory,
136 participant_id: String,
137 algorithm: EncryptionAlgorithm,
138 key_provider: KeyProvider,
139 receiver: RtpReceiver,
140 ) -> Self {
141 let observer = Arc::new(RtcFrameCryptorObserver::default());
142 let sys_handle = sys_fc::ffi::new_frame_cryptor_for_rtp_receiver(
143 peer_factory.handle.sys_handle.clone(),
144 participant_id,
145 algorithm.into(),
146 key_provider.sys_handle,
147 receiver.handle.sys_handle,
148 );
149 let fc = Self { observer: observer.clone(), sys_handle: sys_handle.clone() };
150 fc.sys_handle
151 .register_observer(Box::new(sys_fc::RtcFrameCryptorObserverWrapper::new(observer)));
152 fc
153 }
154
155 pub fn set_enabled(self: &FrameCryptor, enabled: bool) {
156 self.sys_handle.set_enabled(enabled);
157 }
158
159 pub fn enabled(self: &FrameCryptor) -> bool {
160 self.sys_handle.enabled()
161 }
162
163 pub fn set_key_index(self: &FrameCryptor, index: i32) {
164 self.sys_handle.set_key_index(index);
165 }
166
167 pub fn key_index(self: &FrameCryptor) -> i32 {
168 self.sys_handle.key_index()
169 }
170
171 pub fn participant_id(self: &FrameCryptor) -> String {
172 self.sys_handle.participant_id()
173 }
174
175 pub fn on_state_change(&self, handler: Option<OnStateChange>) {
176 *self.observer.state_change_handler.lock() = handler;
177 }
178}
179
180#[derive(Clone)]
181pub struct DataPacketCryptor {
182 pub(crate) sys_handle: SharedPtr<sys_fc::ffi::DataPacketCryptor>,
183}
184
185impl DataPacketCryptor {
186 pub fn new(algorithm: EncryptionAlgorithm, key_provider: KeyProvider) -> Self {
187 Self {
188 sys_handle: sys_fc::ffi::new_data_packet_cryptor(
189 algorithm.into(),
190 key_provider.sys_handle,
191 ),
192 }
193 }
194
195 pub fn encrypt(
196 &self,
197 participant_id: &str,
198 key_index: u32,
199 data: &[u8],
200 ) -> Result<EncryptedPacket, Box<dyn std::error::Error>> {
201 let data_vec: Vec<u8> = data.to_vec();
202 match self.sys_handle.encrypt_data_packet(participant_id.to_string(), key_index, data_vec) {
203 Ok(packet) => Ok(packet.into()),
204 Err(e) => Err(format!("Encryption failed: {}", e).into()),
205 }
206 }
207
208 pub fn decrypt(
209 &self,
210 participant_id: &str,
211 encrypted_packet: &EncryptedPacket,
212 ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
213 match self
214 .sys_handle
215 .decrypt_data_packet(participant_id.to_string(), &encrypted_packet.clone().into())
216 {
217 Ok(data) => Ok(data.into_iter().collect()),
218 Err(e) => Err(format!("Decryption failed: {}", e).into()),
219 }
220 }
221}
222
223#[derive(Default)]
224struct RtcFrameCryptorObserver {
225 state_change_handler: Mutex<Option<OnStateChange>>,
226}
227
228impl sys_fc::RtcFrameCryptorObserver for RtcFrameCryptorObserver {
229 fn on_frame_cryption_state_change(
230 &self,
231 participant_id: String,
232 state: sys_fc::ffi::FrameCryptionState,
233 ) {
234 let mut handler = self.state_change_handler.lock();
235 if let Some(f) = handler.as_mut() {
236 f(participant_id, state.into());
237 }
238 }
239}
240
241impl From<sys_fc::ffi::Algorithm> for EncryptionAlgorithm {
242 fn from(value: sys_fc::ffi::Algorithm) -> Self {
243 match value {
244 sys_fc::ffi::Algorithm::AesGcm => Self::AesGcm,
245 sys_fc::ffi::Algorithm::AesCbc => Self::AesCbc,
246 _ => panic!("unknown frame cyrptor Algorithm"),
247 }
248 }
249}
250
251impl From<EncryptionAlgorithm> for sys_fc::ffi::Algorithm {
252 fn from(value: EncryptionAlgorithm) -> Self {
253 match value {
254 EncryptionAlgorithm::AesGcm => Self::AesGcm,
255 EncryptionAlgorithm::AesCbc => Self::AesCbc,
256 }
257 }
258}
259
260impl From<sys_fc::ffi::FrameCryptionState> for EncryptionState {
261 fn from(value: sys_fc::ffi::FrameCryptionState) -> Self {
262 match value {
263 sys_fc::ffi::FrameCryptionState::New => Self::New,
264 sys_fc::ffi::FrameCryptionState::Ok => Self::Ok,
265 sys_fc::ffi::FrameCryptionState::EncryptionFailed => Self::EncryptionFailed,
266 sys_fc::ffi::FrameCryptionState::DecryptionFailed => Self::DecryptionFailed,
267 sys_fc::ffi::FrameCryptionState::MissingKey => Self::MissingKey,
268 sys_fc::ffi::FrameCryptionState::KeyRatcheted => Self::KeyRatcheted,
269 sys_fc::ffi::FrameCryptionState::InternalError => Self::InternalError,
270 _ => panic!("unknown frame cyrptor FrameCryptionState"),
271 }
272 }
273}
274
275impl From<KeyDerivationAlgorithm> for sys_fc::ffi::KeyDerivationAlgorithm {
276 fn from(value: KeyDerivationAlgorithm) -> Self {
277 match value {
278 KeyDerivationAlgorithm::Pbkdf2 => Self::Pbkdf2,
279 KeyDerivationAlgorithm::Hkdf => Self::Hkdf,
280 }
281 }
282}
283
284impl From<sys_fc::ffi::KeyDerivationAlgorithm> for KeyDerivationAlgorithm {
285 fn from(value: sys_fc::ffi::KeyDerivationAlgorithm) -> Self {
286 match value {
287 sys_fc::ffi::KeyDerivationAlgorithm::Pbkdf2 => Self::Pbkdf2,
288 sys_fc::ffi::KeyDerivationAlgorithm::Hkdf => Self::Hkdf,
289 _ => panic!("unknown KeyDerivationAlgorithm"),
290 }
291 }
292}
293
294impl From<KeyProviderOptions> for sys_fc::ffi::KeyProviderOptions {
295 fn from(value: KeyProviderOptions) -> Self {
296 Self {
297 shared_key: value.shared_key,
298 ratchet_window_size: value.ratchet_window_size,
299 ratchet_salt: value.ratchet_salt,
300 failure_tolerance: value.failure_tolerance,
301 key_derivation_algorithm: value.key_derivation_algorithm.into(),
302 }
303 }
304}
305
306impl From<sys_fc::ffi::EncryptedPacket> for EncryptedPacket {
307 fn from(value: sys_fc::ffi::EncryptedPacket) -> Self {
308 Self {
309 data: value.data.into_iter().collect(),
310 iv: value.iv.into_iter().collect(),
311 key_index: value.key_index,
312 }
313 }
314}
315
316impl From<EncryptedPacket> for sys_fc::ffi::EncryptedPacket {
317 fn from(value: EncryptedPacket) -> Self {
318 Self { data: value.data, iv: value.iv, key_index: value.key_index }
319 }
320}