1use std::sync::{Arc, Mutex};
2use std::time::{Duration, Instant};
3
4use async_trait::async_trait;
5use ed25519_dalek::Signature;
6
7use crate::crypto::{identity::NodeCredentials, KeyExchange, SessionKeys, X25519KeyExchange};
8use crate::handshake::{
9 client::ClientHandshake, server::ServerHandshake, ChallengeAuthenticator, HandshakeContext,
10 HandshakeError, HandshakeOutcome, HandshakeParticipant, HandshakeTransport,
11};
12use crate::messages::{CapabilitySet, DeviceIdentity, SessionEstablished};
13use crate::profile::{CompiledStreamProfile, StreamProfile};
14
15pub mod state;
16use state::{SessionState, SessionStateError};
17
18impl From<SessionStateError> for HandshakeError {
19 fn from(err: SessionStateError) -> Self {
20 HandshakeError::Protocol(err.to_string())
21 }
22}
23
/// Which side of the ALNP handshake a session plays.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so roles can be used as
/// keys in sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AlnpRole {
    /// Initiates the handshake (see `AlnpSession::connect`).
    Controller,
    /// Accepts the handshake (see `AlnpSession::accept`).
    Node,
}
29
/// Strategy selector stored on a session via `AlnpSession::set_jitter_strategy`.
///
/// NOTE(review): how each variant is applied to the stream is decided by code
/// outside this module; the names suggest hold-last-value / drop / linear
/// interpolation — confirm against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitterStrategy {
    HoldLast,
    Drop,
    Lerp,
}
36
37#[derive(Debug, Clone)]
38pub struct AlnpSession {
39 pub role: AlnpRole,
40 state: Arc<Mutex<SessionState>>,
41 last_keepalive: Arc<Mutex<Instant>>,
42 jitter: Arc<Mutex<JitterStrategy>>,
43 streaming_enabled: Arc<Mutex<bool>>,
44 timeout: Duration,
45 session_established: Arc<Mutex<Option<SessionEstablished>>>,
46 session_keys: Arc<Mutex<Option<SessionKeys>>>,
47 compiled_profile: Arc<Mutex<Option<CompiledStreamProfile>>>,
48 profile_locked: Arc<Mutex<bool>>,
49}
50
51impl AlnpSession {
52 pub fn new(role: AlnpRole) -> Self {
53 Self {
54 role,
55 state: Arc::new(Mutex::new(SessionState::Init)),
56 last_keepalive: Arc::new(Mutex::new(Instant::now())),
57 jitter: Arc::new(Mutex::new(JitterStrategy::HoldLast)),
58 streaming_enabled: Arc::new(Mutex::new(true)),
59 timeout: Duration::from_secs(10),
60 session_established: Arc::new(Mutex::new(None)),
61 session_keys: Arc::new(Mutex::new(None)),
62 compiled_profile: Arc::new(Mutex::new(None)),
63 profile_locked: Arc::new(Mutex::new(false)),
64 }
65 }
66
67 pub fn established(&self) -> Option<SessionEstablished> {
68 self.session_established.lock().ok().and_then(|s| s.clone())
69 }
70
71 pub fn keys(&self) -> Option<SessionKeys> {
72 self.session_keys.lock().ok().and_then(|k| k.clone())
73 }
74
75 pub fn state(&self) -> SessionState {
76 self.state
77 .lock()
78 .map(|g| g.clone())
79 .unwrap_or(SessionState::Failed("state poisoned".to_string()))
80 }
81
82 pub fn ensure_streaming_ready(&self) -> Result<SessionEstablished, HandshakeError> {
83 let state = self.state();
84 match state {
85 SessionState::Ready { .. } | SessionState::Streaming { .. } => {
86 self.established().ok_or_else(|| {
87 HandshakeError::Authentication(
88 "session missing even though state is ready".into(),
89 )
90 })
91 }
92 SessionState::Failed(reason) => Err(HandshakeError::Authentication(reason)),
93 _ => Err(HandshakeError::Authentication(
94 "session not ready; streaming blocked".into(),
95 )),
96 }
97 }
98
99 pub fn update_keepalive(&self) {
100 if let Ok(mut k) = self.last_keepalive.lock() {
101 *k = Instant::now();
102 }
103 }
104
105 pub fn check_timeouts(&self) -> Result<(), HandshakeError> {
106 let now = Instant::now();
107 if let Ok(state) = self.state.lock() {
108 if state.check_timeout(self.timeout, now) {
109 self.fail("session timeout".into());
110 return Err(HandshakeError::Transport("session timeout".into()));
111 }
112 }
113 Ok(())
114 }
115
116 pub fn set_stream_profile(
120 &self,
121 profile: CompiledStreamProfile,
122 ) -> Result<(), HandshakeError> {
123 let locked = self
124 .profile_locked
125 .lock()
126 .map_err(|_| HandshakeError::Protocol("profile lock poisoned".into()))?;
127 if *locked {
128 return Err(HandshakeError::Protocol(
129 "stream profile cannot be changed after streaming starts".into(),
130 ));
131 }
132 let mut compiled = self
133 .compiled_profile
134 .lock()
135 .map_err(|_| HandshakeError::Protocol("compiled profile lock poisoned".into()))?;
136 *compiled = Some(profile);
137 Ok(())
138 }
139
140 #[must_use]
144 pub fn profile_config_id(&self) -> Option<String> {
145 self.compiled_profile
146 .lock()
147 .ok()
148 .and_then(|guard| guard.clone().map(|profile| profile.config_id().to_string()))
149 }
150
151 #[must_use]
155 pub fn compiled_profile(&self) -> Option<CompiledStreamProfile> {
156 self.compiled_profile
157 .lock()
158 .ok()
159 .and_then(|guard| guard.clone())
160 }
161
162 #[cfg(test)]
163 pub(crate) fn set_locked_profile_for_testing(&self, profile: CompiledStreamProfile) {
164 let mut compiled = self.compiled_profile.lock().unwrap();
165 *compiled = Some(profile);
166 *self.profile_locked.lock().unwrap() = true;
167 }
168
169 pub fn set_jitter_strategy(&self, strat: JitterStrategy) {
170 if let Ok(mut j) = self.jitter.lock() {
171 *j = strat;
172 }
173 }
174
175 pub fn jitter_strategy(&self) -> JitterStrategy {
176 self.jitter
177 .lock()
178 .map(|j| *j)
179 .unwrap_or(JitterStrategy::Drop)
180 }
181
182 pub fn close(&self) {
183 if let Ok(mut state) = self.state.lock() {
184 *state = SessionState::Closed;
185 }
186 }
187
188 pub fn fail(&self, reason: String) {
189 if let Ok(mut state) = self.state.lock() {
190 *state = SessionState::Failed(reason);
191 }
192 }
193
194 fn transition(&self, next: SessionState) -> Result<(), SessionStateError> {
195 let mut state = self.state.lock().unwrap();
196 let current = state.clone();
197 *state = current.transition(next)?;
198 Ok(())
199 }
200
201 pub fn set_streaming_enabled(&self, enabled: bool) {
202 if let Ok(mut flag) = self.streaming_enabled.lock() {
203 *flag = enabled;
204 }
205 }
206
207 pub fn mark_streaming(&self) {
208 if let Ok(mut state) = self.state.lock() {
209 let current = state.clone();
210 if let SessionState::Ready { .. } = current {
211 let _ = current
212 .transition(SessionState::Streaming {
213 since: Instant::now(),
214 })
215 .map(|next| *state = next);
216 }
217 }
218 if let Ok(mut locked) = self.profile_locked.lock() {
219 *locked = true;
220 }
221 }
222
223 pub fn streaming_enabled(&self) -> bool {
224 self.streaming_enabled.lock().map(|f| *f).unwrap_or(false)
225 }
226
227 fn apply_outcome(&self, outcome: HandshakeOutcome) {
228 if let Ok(mut guard) = self.session_established.lock() {
229 *guard = Some(outcome.established);
230 }
231 if let Ok(mut guard) = self.session_keys.lock() {
232 *guard = Some(outcome.keys);
233 }
234 }
235
236 pub async fn connect<T, A, K>(
237 identity: DeviceIdentity,
238 capabilities: CapabilitySet,
239 authenticator: A,
240 key_exchange: K,
241 context: HandshakeContext,
242 transport: &mut T,
243 ) -> Result<Self, HandshakeError>
244 where
245 T: HandshakeTransport + Send,
246 A: ChallengeAuthenticator + Send + Sync,
247 K: KeyExchange + Send + Sync,
248 {
249 let session = Self::new(AlnpRole::Controller);
250 session.transition(SessionState::Handshake)?;
251 let driver = ClientHandshake {
252 identity,
253 capabilities,
254 authenticator,
255 key_exchange,
256 context,
257 };
258
259 let outcome = driver.run(transport).await?;
260 session.transition(SessionState::Authenticated {
261 since: Instant::now(),
262 })?;
263 session.transition(SessionState::Ready {
264 since: Instant::now(),
265 })?;
266 session.apply_outcome(outcome);
267 Ok(session)
268 }
269
270 pub async fn accept<T, A, K>(
271 identity: DeviceIdentity,
272 capabilities: CapabilitySet,
273 authenticator: A,
274 key_exchange: K,
275 context: HandshakeContext,
276 transport: &mut T,
277 ) -> Result<Self, HandshakeError>
278 where
279 T: HandshakeTransport + Send,
280 A: ChallengeAuthenticator + Send + Sync,
281 K: KeyExchange + Send + Sync,
282 {
283 let session = Self::new(AlnpRole::Node);
284 session.transition(SessionState::Handshake)?;
285 let driver = ServerHandshake {
286 identity,
287 capabilities,
288 authenticator,
289 key_exchange,
290 context,
291 };
292
293 let outcome = driver.run(transport).await?;
294 session.transition(SessionState::Authenticated {
295 since: Instant::now(),
296 })?;
297 session.transition(SessionState::Ready {
298 since: Instant::now(),
299 })?;
300 session.apply_outcome(outcome);
301 Ok(session)
302 }
303}
304
/// Challenge authenticator based on a pre-shared byte string.
///
/// SECURITY: its challenge response is the secret followed by the nonce, so the
/// secret is revealed to anyone who observes a response. Intended for tests and
/// demos; prefer `Ed25519Authenticator` for real deployments. `Debug` is
/// deliberately not derived so the secret is not printed by accident.
pub struct StaticKeyAuthenticator {
    /// Shared secret placed at the front of every challenge response.
    secret: Vec<u8>,
}
309
310impl StaticKeyAuthenticator {
311 pub fn new(secret: Vec<u8>) -> Self {
312 Self { secret }
313 }
314}
315
316impl Default for StaticKeyAuthenticator {
317 fn default() -> Self {
318 Self::new(b"default-alnp-secret".to_vec())
319 }
320}
321
322impl ChallengeAuthenticator for StaticKeyAuthenticator {
323 fn sign_challenge(&self, nonce: &[u8]) -> Vec<u8> {
324 let mut sig = Vec::with_capacity(self.secret.len() + nonce.len());
325 sig.extend_from_slice(&self.secret);
326 sig.extend_from_slice(nonce);
327 sig
328 }
329
330 fn verify_challenge(&self, nonce: &[u8], signature: &[u8]) -> bool {
331 signature.ends_with(nonce) && signature.starts_with(&self.secret)
332 }
333}
334
335pub struct Ed25519Authenticator {
337 creds: NodeCredentials,
338}
339
340impl Ed25519Authenticator {
341 pub fn new(creds: NodeCredentials) -> Self {
342 Self { creds }
343 }
344}
345
346impl ChallengeAuthenticator for Ed25519Authenticator {
347 fn sign_challenge(&self, nonce: &[u8]) -> Vec<u8> {
348 self.creds.sign(nonce).to_vec()
349 }
350
351 fn verify_challenge(&self, nonce: &[u8], signature: &[u8]) -> bool {
352 if let Ok(sig) = Signature::from_slice(signature) {
353 self.creds.verify(nonce, &sig)
354 } else {
355 false
356 }
357 }
358}
359
360pub struct LoopbackTransport {
362 inbox: Vec<crate::handshake::HandshakeMessage>,
363}
364
365impl LoopbackTransport {
366 pub fn new() -> Self {
367 Self { inbox: Vec::new() }
368 }
369}
370
#[cfg(test)]
mod session_tests {
    use super::*;

    /// After `mark_streaming`, replacing the stream profile must be rejected.
    #[test]
    fn profile_lock_prevents_profile_swaps() {
        let session = AlnpSession::new(AlnpRole::Controller);
        let profile = StreamProfile::auto().compile().unwrap();

        session.set_stream_profile(profile.clone()).unwrap();
        session.mark_streaming();

        let second_attempt = session.set_stream_profile(profile);
        assert!(second_attempt.is_err());
    }

    /// The reported config id is the id of the profile that was set.
    #[test]
    fn config_id_matches_profile() {
        let session = AlnpSession::new(AlnpRole::Controller);
        let profile = StreamProfile::realtime().compile().unwrap();

        session.set_stream_profile(profile.clone()).unwrap();

        let reported = session.profile_config_id().unwrap();
        assert_eq!(reported, profile.config_id());
    }
}
395
396#[async_trait]
397impl HandshakeTransport for LoopbackTransport {
398 async fn send(
399 &mut self,
400 msg: crate::handshake::HandshakeMessage,
401 ) -> Result<(), HandshakeError> {
402 self.inbox.push(msg);
403 Ok(())
404 }
405
406 async fn recv(&mut self) -> Result<crate::handshake::HandshakeMessage, HandshakeError> {
407 if self.inbox.is_empty() {
408 return Err(HandshakeError::Transport("loopback queue empty".into()));
409 }
410 Ok(self.inbox.remove(0))
411 }
412}
413
414pub async fn example_controller_session<T: HandshakeTransport + Send>(
416 identity: DeviceIdentity,
417 transport: &mut T,
418) -> Result<AlnpSession, HandshakeError> {
419 AlnpSession::connect(
420 identity,
421 CapabilitySet::default(),
422 StaticKeyAuthenticator::default(),
423 X25519KeyExchange::new(),
424 HandshakeContext::default(),
425 transport,
426 )
427 .await
428}
429
430pub async fn example_node_session<T: HandshakeTransport + Send>(
432 identity: DeviceIdentity,
433 transport: &mut T,
434) -> Result<AlnpSession, HandshakeError> {
435 AlnpSession::accept(
436 identity,
437 CapabilitySet::default(),
438 StaticKeyAuthenticator::default(),
439 X25519KeyExchange::new(),
440 HandshakeContext::default(),
441 transport,
442 )
443 .await
444}