1use std::sync::{Arc, Mutex};
2use std::time::{Duration, Instant};
3
4use async_trait::async_trait;
5use ed25519_dalek::Signature;
6
7use crate::crypto::{identity::NodeCredentials, KeyExchange, SessionKeys, X25519KeyExchange};
8use crate::handshake::{
9 client::ClientHandshake, server::ServerHandshake, ChallengeAuthenticator, HandshakeContext,
10 HandshakeError, HandshakeOutcome, HandshakeParticipant, HandshakeTransport,
11};
12use crate::messages::{CapabilitySet, DeviceIdentity, SessionEstablished};
13
14pub mod state;
15use state::{SessionState, SessionStateError};
16
17impl From<SessionStateError> for HandshakeError {
18 fn from(err: SessionStateError) -> Self {
19 HandshakeError::Protocol(err.to_string())
20 }
21}
22
/// Which side of the ALNP handshake a session played.
///
/// `Controller` sessions are created by `AlnpSession::connect` (client side);
/// `Node` sessions are created by `AlnpSession::accept` (server side).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AlnpRole {
    /// Initiating side of the handshake.
    Controller,
    /// Accepting side of the handshake.
    Node,
}
28
/// Policy selector stored on a session via `AlnpSession::set_jitter_strategy`.
///
/// This file only stores and returns the choice. The per-variant meanings
/// below are inferred from the names and should be confirmed against the
/// code that consumes the strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitterStrategy {
    /// Presumably re-uses the last received value (the session default).
    HoldLast,
    /// Presumably discards late/missing data.
    Drop,
    /// Presumably linearly interpolates between values.
    Lerp,
}
35
/// Handle to one ALNP session, on either the controller or the node side.
///
/// Every mutable piece of state sits behind `Arc<Mutex<_>>`, so a handle
/// produced by the derived `Clone` shares that state with the original.
/// `role` and `timeout` are plain values copied into each handle.
#[derive(Debug, Clone)]
pub struct AlnpSession {
    /// `Controller` when built by `connect`, `Node` when built by `accept`.
    pub role: AlnpRole,
    /// Current lifecycle state (`Init` → `Handshake` → `Authenticated` → `Ready` …).
    state: Arc<Mutex<SessionState>>,
    /// Set to `Instant::now()` at construction and by `update_keepalive`.
    /// NOTE(review): not read by any code visible in this file — confirm intended use.
    last_keepalive: Arc<Mutex<Instant>>,
    /// Strategy returned by `jitter_strategy`; `HoldLast` at construction.
    jitter: Arc<Mutex<JitterStrategy>>,
    /// Flag toggled by `set_streaming_enabled`; `true` at construction.
    streaming_enabled: Arc<Mutex<bool>>,
    /// Limit handed to `SessionState::check_timeout` by `check_timeouts`.
    timeout: Duration,
    /// Handshake record, filled in once a handshake completes.
    session_established: Arc<Mutex<Option<SessionEstablished>>>,
    /// Negotiated session keys, filled in once a handshake completes.
    session_keys: Arc<Mutex<Option<SessionKeys>>>,
}
47
48impl AlnpSession {
49 pub fn new(role: AlnpRole) -> Self {
50 Self {
51 role,
52 state: Arc::new(Mutex::new(SessionState::Init)),
53 last_keepalive: Arc::new(Mutex::new(Instant::now())),
54 jitter: Arc::new(Mutex::new(JitterStrategy::HoldLast)),
55 streaming_enabled: Arc::new(Mutex::new(true)),
56 timeout: Duration::from_secs(10),
57 session_established: Arc::new(Mutex::new(None)),
58 session_keys: Arc::new(Mutex::new(None)),
59 }
60 }
61
62 pub fn established(&self) -> Option<SessionEstablished> {
63 self.session_established.lock().ok().and_then(|s| s.clone())
64 }
65
66 pub fn keys(&self) -> Option<SessionKeys> {
67 self.session_keys.lock().ok().and_then(|k| k.clone())
68 }
69
70 pub fn state(&self) -> SessionState {
71 self.state
72 .lock()
73 .map(|g| g.clone())
74 .unwrap_or(SessionState::Failed("state poisoned".to_string()))
75 }
76
77 pub fn ensure_streaming_ready(&self) -> Result<SessionEstablished, HandshakeError> {
78 let state = self.state();
79 match state {
80 SessionState::Ready { .. } | SessionState::Streaming { .. } => {
81 self.established().ok_or_else(|| {
82 HandshakeError::Authentication(
83 "session missing even though state is ready".into(),
84 )
85 })
86 }
87 SessionState::Failed(reason) => Err(HandshakeError::Authentication(reason)),
88 _ => Err(HandshakeError::Authentication(
89 "session not ready; streaming blocked".into(),
90 )),
91 }
92 }
93
94 pub fn update_keepalive(&self) {
95 if let Ok(mut k) = self.last_keepalive.lock() {
96 *k = Instant::now();
97 }
98 }
99
100 pub fn check_timeouts(&self) -> Result<(), HandshakeError> {
101 let now = Instant::now();
102 if let Ok(state) = self.state.lock() {
103 if state.check_timeout(self.timeout, now) {
104 self.fail("session timeout".into());
105 return Err(HandshakeError::Transport("session timeout".into()));
106 }
107 }
108 Ok(())
109 }
110
111 pub fn set_jitter_strategy(&self, strat: JitterStrategy) {
112 if let Ok(mut j) = self.jitter.lock() {
113 *j = strat;
114 }
115 }
116
117 pub fn jitter_strategy(&self) -> JitterStrategy {
118 self.jitter
119 .lock()
120 .map(|j| *j)
121 .unwrap_or(JitterStrategy::Drop)
122 }
123
124 pub fn close(&self) {
125 if let Ok(mut state) = self.state.lock() {
126 *state = SessionState::Closed;
127 }
128 }
129
130 pub fn fail(&self, reason: String) {
131 if let Ok(mut state) = self.state.lock() {
132 *state = SessionState::Failed(reason);
133 }
134 }
135
136 fn transition(&self, next: SessionState) -> Result<(), SessionStateError> {
137 let mut state = self.state.lock().unwrap();
138 let current = state.clone();
139 *state = current.transition(next)?;
140 Ok(())
141 }
142
143 pub fn set_streaming_enabled(&self, enabled: bool) {
144 if let Ok(mut flag) = self.streaming_enabled.lock() {
145 *flag = enabled;
146 }
147 }
148
149 pub fn mark_streaming(&self) {
150 if let Ok(mut state) = self.state.lock() {
151 let current = state.clone();
152 if let SessionState::Ready { .. } = current {
153 let _ = current
154 .transition(SessionState::Streaming {
155 since: Instant::now(),
156 })
157 .map(|next| *state = next);
158 }
159 }
160 }
161
162 pub fn streaming_enabled(&self) -> bool {
163 self.streaming_enabled.lock().map(|f| *f).unwrap_or(false)
164 }
165
166 fn apply_outcome(&self, outcome: HandshakeOutcome) {
167 if let Ok(mut guard) = self.session_established.lock() {
168 *guard = Some(outcome.established);
169 }
170 if let Ok(mut guard) = self.session_keys.lock() {
171 *guard = Some(outcome.keys);
172 }
173 }
174
175 pub async fn connect<T, A, K>(
176 identity: DeviceIdentity,
177 capabilities: CapabilitySet,
178 authenticator: A,
179 key_exchange: K,
180 context: HandshakeContext,
181 transport: &mut T,
182 ) -> Result<Self, HandshakeError>
183 where
184 T: HandshakeTransport + Send,
185 A: ChallengeAuthenticator + Send + Sync,
186 K: KeyExchange + Send + Sync,
187 {
188 let session = Self::new(AlnpRole::Controller);
189 session.transition(SessionState::Handshake)?;
190 let driver = ClientHandshake {
191 identity,
192 capabilities,
193 authenticator,
194 key_exchange,
195 context,
196 };
197
198 let outcome = driver.run(transport).await?;
199 session.transition(SessionState::Authenticated {
200 since: Instant::now(),
201 })?;
202 session.transition(SessionState::Ready {
203 since: Instant::now(),
204 })?;
205 session.apply_outcome(outcome);
206 Ok(session)
207 }
208
209 pub async fn accept<T, A, K>(
210 identity: DeviceIdentity,
211 capabilities: CapabilitySet,
212 authenticator: A,
213 key_exchange: K,
214 context: HandshakeContext,
215 transport: &mut T,
216 ) -> Result<Self, HandshakeError>
217 where
218 T: HandshakeTransport + Send,
219 A: ChallengeAuthenticator + Send + Sync,
220 K: KeyExchange + Send + Sync,
221 {
222 let session = Self::new(AlnpRole::Node);
223 session.transition(SessionState::Handshake)?;
224 let driver = ServerHandshake {
225 identity,
226 capabilities,
227 authenticator,
228 key_exchange,
229 context,
230 };
231
232 let outcome = driver.run(transport).await?;
233 session.transition(SessionState::Authenticated {
234 since: Instant::now(),
235 })?;
236 session.transition(SessionState::Ready {
237 since: Instant::now(),
238 })?;
239 session.apply_outcome(outcome);
240 Ok(session)
241 }
242}
243
/// Shared-secret challenge authenticator.
///
/// Its "signature" is the raw secret followed by the nonce (see its
/// `ChallengeAuthenticator` impl), so anyone who observes a signature learns
/// the secret. Intended for demos/tests rather than production use.
pub struct StaticKeyAuthenticator {
    // Raw shared secret; prepended verbatim to each challenge response.
    secret: Vec<u8>,
}
248
249impl StaticKeyAuthenticator {
250 pub fn new(secret: Vec<u8>) -> Self {
251 Self { secret }
252 }
253}
254
255impl Default for StaticKeyAuthenticator {
256 fn default() -> Self {
257 Self::new(b"default-alnp-secret".to_vec())
258 }
259}
260
261impl ChallengeAuthenticator for StaticKeyAuthenticator {
262 fn sign_challenge(&self, nonce: &[u8]) -> Vec<u8> {
263 let mut sig = Vec::with_capacity(self.secret.len() + nonce.len());
264 sig.extend_from_slice(&self.secret);
265 sig.extend_from_slice(nonce);
266 sig
267 }
268
269 fn verify_challenge(&self, nonce: &[u8], signature: &[u8]) -> bool {
270 signature.ends_with(nonce) && signature.starts_with(&self.secret)
271 }
272}
273
/// Challenge authenticator backed by Ed25519 node credentials.
pub struct Ed25519Authenticator {
    // Used both to sign outgoing challenges and to verify incoming responses.
    creds: NodeCredentials,
}
278
279impl Ed25519Authenticator {
280 pub fn new(creds: NodeCredentials) -> Self {
281 Self { creds }
282 }
283}
284
285impl ChallengeAuthenticator for Ed25519Authenticator {
286 fn sign_challenge(&self, nonce: &[u8]) -> Vec<u8> {
287 self.creds.sign(nonce).to_vec()
288 }
289
290 fn verify_challenge(&self, nonce: &[u8], signature: &[u8]) -> bool {
291 if let Ok(sig) = Signature::from_slice(signature) {
292 self.creds.verify(nonce, &sig)
293 } else {
294 false
295 }
296 }
297}
298
299pub struct LoopbackTransport {
301 inbox: Vec<crate::handshake::HandshakeMessage>,
302}
303
304impl LoopbackTransport {
305 pub fn new() -> Self {
306 Self { inbox: Vec::new() }
307 }
308}
309
310#[async_trait]
311impl HandshakeTransport for LoopbackTransport {
312 async fn send(
313 &mut self,
314 msg: crate::handshake::HandshakeMessage,
315 ) -> Result<(), HandshakeError> {
316 self.inbox.push(msg);
317 Ok(())
318 }
319
320 async fn recv(&mut self) -> Result<crate::handshake::HandshakeMessage, HandshakeError> {
321 if self.inbox.is_empty() {
322 return Err(HandshakeError::Transport("loopback queue empty".into()));
323 }
324 Ok(self.inbox.remove(0))
325 }
326}
327
328pub async fn example_controller_session<T: HandshakeTransport + Send>(
330 identity: DeviceIdentity,
331 transport: &mut T,
332) -> Result<AlnpSession, HandshakeError> {
333 AlnpSession::connect(
334 identity,
335 CapabilitySet::default(),
336 StaticKeyAuthenticator::default(),
337 X25519KeyExchange::new(),
338 HandshakeContext::default(),
339 transport,
340 )
341 .await
342}
343
344pub async fn example_node_session<T: HandshakeTransport + Send>(
346 identity: DeviceIdentity,
347 transport: &mut T,
348) -> Result<AlnpSession, HandshakeError> {
349 AlnpSession::accept(
350 identity,
351 CapabilitySet::default(),
352 StaticKeyAuthenticator::default(),
353 X25519KeyExchange::new(),
354 HandshakeContext::default(),
355 transport,
356 )
357 .await
358}