Skip to main content

gosuto_livekit/room/e2ee/
manager.rs

1// Copyright 2025 LiveKit, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::HashMap, sync::Arc};
16
17use gosuto_libwebrtc::{
18    native::frame_cryptor::{
19        DataPacketCryptor, EncryptedPacket, EncryptionAlgorithm, EncryptionState, FrameCryptor,
20    },
21    rtp_receiver::RtpReceiver,
22    rtp_sender::RtpSender,
23};
24use parking_lot::Mutex;
25
26use super::{key_provider::KeyProvider, EncryptionType};
27use crate::{
28    e2ee::E2eeOptions,
29    id::{ParticipantIdentity, TrackSid},
30    participant::{LocalParticipant, RemoteParticipant},
31    prelude::{LocalTrack, LocalTrackPublication, RemoteTrack, RemoteTrackPublication},
32    rtc_engine::lk_runtime::LkRuntime,
33};
34
35type StateChangedHandler = Box<dyn Fn(ParticipantIdentity, EncryptionState) + Send>;
36
37struct ManagerInner {
38    options: Option<E2eeOptions>, // If Some, it means the e2ee was initialized
39    enabled: bool,                // Used to enable/disable e2ee
40    dc_encryption_enabled: bool,
41    frame_cryptors: HashMap<(ParticipantIdentity, TrackSid), FrameCryptor>,
42    data_packet_cryptor: Option<DataPacketCryptor>,
43}
44
45#[derive(Clone)]
46pub struct E2eeManager {
47    inner: Arc<Mutex<ManagerInner>>,
48    state_changed: Arc<Mutex<Option<StateChangedHandler>>>,
49}
50
51impl E2eeManager {
52    /// E2eeOptions is an optional parameter. We may support to reconfigure e2ee after connect in
53    /// the future.
54    pub(crate) fn new(options: Option<E2eeOptions>, with_dc_encryption: bool) -> Self {
55        // Create DataPacketCryptor whenever E2EE options are available
56        // This allows for decryption even if we're not encrypting our own data
57        let data_packet_cryptor = options.as_ref().map(|opts| {
58            DataPacketCryptor::new(EncryptionAlgorithm::AesGcm, opts.key_provider.handle.clone())
59        });
60
61        Self {
62            inner: Arc::new(Mutex::new(ManagerInner {
63                enabled: options.is_some(), // Enabled by default if options is provided
64                dc_encryption_enabled: options.is_some() && with_dc_encryption,
65                options,
66                frame_cryptors: HashMap::new(),
67                data_packet_cryptor,
68            })),
69            state_changed: Default::default(),
70        }
71    }
72
73    pub(crate) fn cleanup(&self) {
74        let mut inner = self.inner.lock();
75        for cryptor in inner.frame_cryptors.values() {
76            cryptor.set_enabled(false);
77        }
78        inner.frame_cryptors.clear();
79    }
80
81    /// Register to e2ee state changes
82    /// Used by the room to dispatch the event to the room dispatcher
83    pub(crate) fn on_state_changed(
84        &self,
85        handler: impl Fn(ParticipantIdentity, EncryptionState) + Send + 'static,
86    ) {
87        *self.state_changed.lock() = Some(Box::new(handler));
88    }
89
90    pub(crate) fn initialized(&self) -> bool {
91        self.inner.lock().options.is_some()
92    }
93
94    /// Called by the room
95    pub(crate) fn on_track_subscribed(
96        &self,
97        track: RemoteTrack,
98        publication: RemoteTrackPublication,
99        participant: RemoteParticipant,
100    ) {
101        if !self.initialized() {
102            return;
103        }
104
105        if publication.encryption_type() == EncryptionType::None {
106            return;
107        }
108
109        let identity = participant.identity();
110        let receiver = track.transceiver().unwrap().receiver();
111        let frame_cryptor = self.setup_rtp_receiver(&identity, receiver);
112        self.setup_cryptor(&frame_cryptor);
113
114        let mut inner = self.inner.lock();
115        inner.frame_cryptors.insert((identity, publication.sid()), frame_cryptor.clone());
116    }
117
118    /// Called by the room
119    pub(crate) fn on_local_track_published(
120        &self,
121        track: LocalTrack,
122        publication: LocalTrackPublication,
123        participant: LocalParticipant,
124    ) {
125        if !self.initialized() {
126            return;
127        }
128
129        if publication.encryption_type() == EncryptionType::None {
130            return;
131        }
132
133        let identity = participant.identity();
134        let sender = track.transceiver().unwrap().sender();
135        let frame_cryptor = self.setup_rtp_sender(&identity, sender);
136        self.setup_cryptor(&frame_cryptor);
137
138        let mut inner = self.inner.lock();
139        inner.frame_cryptors.insert((identity, publication.sid()), frame_cryptor.clone());
140    }
141
142    fn setup_cryptor(&self, frame_cryptor: &FrameCryptor) {
143        let state_changed = self.state_changed.clone();
144        frame_cryptor.on_state_change(Some(Box::new(move |participant_identity, state| {
145            if let Some(state_changed) = state_changed.lock().as_ref() {
146                state_changed(participant_identity.into(), state);
147            }
148        })));
149    }
150
151    /// Called by the room
152    pub(crate) fn on_local_track_unpublished(
153        &self,
154        publication: LocalTrackPublication,
155        participant: LocalParticipant,
156    ) {
157        self.remove_frame_cryptor(participant.identity(), publication.sid());
158    }
159
160    /// Called by the room
161    pub(crate) fn on_track_unsubscribed(
162        &self,
163        _: RemoteTrack,
164        publication: RemoteTrackPublication,
165        participant: RemoteParticipant,
166    ) {
167        self.remove_frame_cryptor(participant.identity(), publication.sid());
168    }
169
170    pub fn frame_cryptors(&self) -> HashMap<(ParticipantIdentity, TrackSid), FrameCryptor> {
171        self.inner.lock().frame_cryptors.clone()
172    }
173
174    pub fn enabled(&self) -> bool {
175        self.inner.lock().enabled && self.initialized()
176    }
177
178    pub fn is_dc_encryption_enabled(&self) -> bool {
179        self.inner.lock().dc_encryption_enabled
180    }
181
182    pub fn set_enabled(&self, enabled: bool) {
183        let inner = self.inner.lock();
184        if inner.enabled == enabled {
185            return;
186        }
187
188        for (_, cryptor) in inner.frame_cryptors.iter() {
189            cryptor.set_enabled(enabled);
190        }
191    }
192
193    pub fn key_provider(&self) -> Option<KeyProvider> {
194        let inner = self.inner.lock();
195        inner.options.as_ref().map(|opts| opts.key_provider.clone())
196    }
197
198    pub fn encryption_type(&self) -> EncryptionType {
199        let inner = self.inner.lock();
200        inner.options.as_ref().map(|opts| opts.encryption_type).unwrap_or(EncryptionType::None)
201    }
202
203    fn setup_rtp_sender(
204        &self,
205        participant_identity: &ParticipantIdentity,
206        sender: RtpSender,
207    ) -> FrameCryptor {
208        let inner = self.inner.lock();
209        let options = inner.options.as_ref().unwrap();
210
211        let frame_cryptor = FrameCryptor::new_for_rtp_sender(
212            LkRuntime::instance().pc_factory(),
213            participant_identity.to_string(),
214            EncryptionAlgorithm::AesGcm,
215            options.key_provider.handle.clone(),
216            sender,
217        );
218        frame_cryptor.set_enabled(inner.enabled);
219        frame_cryptor
220    }
221
222    fn setup_rtp_receiver(
223        &self,
224        participant_identity: &ParticipantIdentity,
225        receiver: RtpReceiver,
226    ) -> FrameCryptor {
227        let inner = self.inner.lock();
228        let options = inner.options.as_ref().unwrap();
229
230        let frame_cryptor = FrameCryptor::new_for_rtp_receiver(
231            LkRuntime::instance().pc_factory(),
232            participant_identity.to_string(),
233            EncryptionAlgorithm::AesGcm,
234            options.key_provider.handle.clone(),
235            receiver,
236        );
237        frame_cryptor.set_enabled(inner.enabled);
238        frame_cryptor
239    }
240
241    fn remove_frame_cryptor(&self, participant_identity: ParticipantIdentity, track_sid: TrackSid) {
242        log::debug!("removing frame cryptor for {}", participant_identity);
243
244        let mut inner = self.inner.lock();
245        inner.frame_cryptors.remove(&(participant_identity, track_sid));
246    }
247
248    /// Decrypt data received from a data channel
249    pub fn handle_encrypted_data(
250        &self,
251        data: &[u8],
252        iv: &[u8],
253        participant_identity: &str,
254        key_index: u32,
255    ) -> Option<Vec<u8>> {
256        let inner = self.inner.lock();
257
258        let data_packet_cryptor = inner.data_packet_cryptor.as_ref()?;
259
260        let encrypted_packet = EncryptedPacket { data: data.to_vec(), iv: iv.to_vec(), key_index };
261
262        match data_packet_cryptor.decrypt(participant_identity, &encrypted_packet) {
263            Ok(decrypted_data) => Some(decrypted_data),
264            Err(e) => {
265                log::warn!("handle_encrypted_data error: {}", e);
266                None
267            }
268        }
269    }
270
271    /// Encrypt data for transmission over a data channel
272    pub fn encrypt_data(
273        &self,
274        data: &[u8],
275        participant_identity: &str,
276        key_index: u32,
277    ) -> Result<EncryptedPacket, Box<dyn std::error::Error>> {
278        let inner = self.inner.lock();
279
280        let data_packet_cryptor =
281            inner.data_packet_cryptor.as_ref().ok_or("DataPacketCryptor is not initialized")?;
282
283        data_packet_cryptor.encrypt(participant_identity, key_index, data)
284    }
285}