1use std::sync::{Arc, RwLock};
9
10use reqwest::cookie::{CookieStore, Jar};
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13use tracing::{debug, trace};
14use url::Url;
15
16use crate::auth::ControllerPlatform;
17use crate::error::Error;
18use crate::session::models::SessionResponse;
19use crate::transport::TransportConfig;
20
21#[derive(serde::Deserialize)]
23struct UnifiOsError {
24 error: Option<UnifiOsErrorInner>,
25}
26
27#[derive(serde::Deserialize)]
28struct UnifiOsErrorInner {
29 code: u16,
30 message: Option<String>,
31}
32
/// How a `SessionClient` is authenticated with the controller.
///
/// The mode decides which error an HTTP 401 (or a UniFi OS error with code
/// 401) is reported as.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SessionAuth {
    /// Cookie-based session (the mode `SessionClient::new` uses); a 401 is
    /// reported as `Error::SessionExpired`.
    Cookie,
    /// API-key authentication; a 401 is reported as `Error::InvalidApiKey`.
    ApiKey,
}
42
/// HTTP client for the controller's session API.
///
/// Builds endpoint URLs from the base URL, the platform's optional session
/// prefix and the configured site; attaches the current CSRF token to
/// mutating requests; and turns HTTP/API failures into crate `Error`s.
pub struct SessionClient {
    /// Client used to send every request.
    http: reqwest::Client,
    /// Controller base URL; endpoint paths are appended to it.
    base_url: Url,
    /// Site name substituted into site-scoped endpoint paths.
    site: String,
    /// Target platform; supplies the optional session path prefix.
    platform: ControllerPlatform,
    /// Authentication mode; selects which error a 401 maps to.
    auth: SessionAuth,
    /// Latest CSRF token, sent as `X-CSRF-Token` and refreshed from response
    /// headers. Kept behind a lock so it can be updated through `&self`.
    csrf_token: RwLock<Option<String>>,
    /// Cookie jar set up at construction (`None` for `with_client`);
    /// presumably the same jar `http` was built with — used to read and seed
    /// cookies.
    cookie_jar: Option<Arc<Jar>>,
}
62
63impl SessionClient {
64 pub fn new(
71 base_url: Url,
72 site: String,
73 platform: ControllerPlatform,
74 transport: &TransportConfig,
75 ) -> Result<Self, Error> {
76 let config = if transport.cookie_jar.is_some() {
77 transport.clone()
78 } else {
79 transport.clone().with_cookie_jar()
80 };
81 let cookie_jar = config.cookie_jar.clone();
82 let http = config.build_client()?;
83 Ok(Self {
84 http,
85 base_url,
86 site,
87 platform,
88 auth: SessionAuth::Cookie,
89 csrf_token: RwLock::new(None),
90 cookie_jar,
91 })
92 }
93
94 pub fn with_client(
99 http: reqwest::Client,
100 base_url: Url,
101 site: String,
102 platform: ControllerPlatform,
103 auth: SessionAuth,
104 ) -> Self {
105 Self {
106 http,
107 base_url,
108 site,
109 platform,
110 auth,
111 csrf_token: RwLock::new(None),
112 cookie_jar: None,
113 }
114 }
115
116 pub fn auth(&self) -> SessionAuth {
118 self.auth
119 }
120
121 pub fn site(&self) -> &str {
123 &self.site
124 }
125
126 pub fn http(&self) -> &reqwest::Client {
128 &self.http
129 }
130
131 pub fn base_url(&self) -> &Url {
133 &self.base_url
134 }
135
136 pub fn platform(&self) -> ControllerPlatform {
138 self.platform
139 }
140
141 pub fn cookie_header(&self) -> Option<String> {
146 let jar = self.cookie_jar.as_ref()?;
147 let cookies = jar.cookies(&self.base_url)?;
148 cookies.to_str().ok().map(String::from)
149 }
150
151 pub(crate) fn add_cookie(&self, set_cookie_value: &str, url: &Url) -> Result<(), Error> {
158 let jar = self
159 .cookie_jar
160 .as_ref()
161 .ok_or_else(|| Error::Authentication {
162 message: "no cookie jar available for MFA flow".into(),
163 })?;
164 let header_value: reqwest::header::HeaderValue =
165 set_cookie_value
166 .parse()
167 .map_err(|_| Error::Authentication {
168 message: "failed to parse MFA cookie value".into(),
169 })?;
170 jar.set_cookies(&mut std::iter::once(&header_value), url);
171 Ok(())
172 }
173
174 pub(crate) fn csrf_token_value(&self) -> Option<String> {
178 self.csrf_token.read().expect("CSRF lock poisoned").clone()
179 }
180
181 pub(crate) fn set_csrf_token(&self, token: String) {
183 debug!("storing CSRF token");
184 *self.csrf_token.write().expect("CSRF lock poisoned") = Some(token);
185 }
186
187 fn update_csrf_from_response(&self, headers: &reqwest::header::HeaderMap) {
189 let new_token = headers
191 .get("X-Updated-CSRF-Token")
192 .or_else(|| headers.get("x-csrf-token"))
193 .and_then(|v| v.to_str().ok())
194 .map(String::from);
195
196 if let Some(token) = new_token {
197 trace!("CSRF token rotated");
198 *self.csrf_token.write().expect("CSRF lock poisoned") = Some(token);
199 }
200 }
201
202 fn apply_csrf(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
204 let guard = self.csrf_token.read().expect("CSRF lock poisoned");
205 match guard.as_deref() {
206 Some(token) => builder.header("X-CSRF-Token", token),
207 None => builder,
208 }
209 }
210
211 fn unauthorized_error(&self) -> Error {
216 match self.auth {
217 SessionAuth::Cookie => Error::SessionExpired,
218 SessionAuth::ApiKey => Error::InvalidApiKey,
219 }
220 }
221
222 pub(crate) fn api_url(&self, path: &str) -> Url {
229 let prefix = self.platform.session_prefix().unwrap_or("");
230 let base = self.base_url.as_str().trim_end_matches('/');
231 let prefix = prefix.trim_end_matches('/');
232 let full = format!("{base}{prefix}/api/{path}");
233 Url::parse(&full).expect("invalid API URL")
234 }
235
236 pub(crate) fn site_url(&self, path: &str) -> Url {
240 let prefix = self.platform.session_prefix().unwrap_or("");
241 let base = self.base_url.as_str().trim_end_matches('/');
242 let prefix = prefix.trim_end_matches('/');
243 let full = format!("{base}{prefix}/api/s/{}/{path}", self.site);
244 Url::parse(&full).expect("invalid site URL")
245 }
246
247 pub(crate) fn site_url_v2(&self, path: &str) -> Url {
252 let prefix = self.platform.session_prefix().unwrap_or("");
253 let base = self.base_url.as_str().trim_end_matches('/');
254 let prefix = prefix.trim_end_matches('/');
255 let full = format!("{base}{prefix}/v2/api/site/{}/{path}", self.site);
256 Url::parse(&full).expect("invalid v2 site URL")
257 }
258
259 pub(crate) async fn get<T: DeserializeOwned>(&self, url: Url) -> Result<Vec<T>, Error> {
263 debug!("GET {}", url);
264
265 let resp = self.http.get(url).send().await.map_err(Error::Transport)?;
266
267 self.parse_envelope(resp).await
268 }
269
270 pub(crate) async fn get_raw(&self, url: Url) -> Result<serde_json::Value, Error> {
275 debug!("GET (raw) {}", url);
276
277 let resp = self.http.get(url).send().await.map_err(Error::Transport)?;
278 let status = resp.status();
279
280 if status == reqwest::StatusCode::UNAUTHORIZED {
281 return Err(self.unauthorized_error());
282 }
283 if !status.is_success() {
284 let body = resp.text().await.unwrap_or_default();
285 return Err(Error::SessionApi {
286 message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
287 });
288 }
289
290 let body = resp.text().await.map_err(Error::Transport)?;
291 serde_json::from_str(&body).map_err(|e| Error::Deserialization {
292 message: format!("{e}"),
293 body,
294 })
295 }
296
297 pub(crate) async fn post<T: DeserializeOwned>(
299 &self,
300 url: Url,
301 body: &(impl Serialize + Sync),
302 ) -> Result<Vec<T>, Error> {
303 debug!("POST {}", url);
304
305 let builder = self.apply_csrf(self.http.post(url).json(body));
306 let resp = builder.send().await.map_err(Error::Transport)?;
307
308 self.parse_envelope(resp).await
309 }
310
311 #[allow(dead_code)]
313 pub(crate) async fn put<T: DeserializeOwned>(
314 &self,
315 url: Url,
316 body: &(impl Serialize + Sync),
317 ) -> Result<Vec<T>, Error> {
318 debug!("PUT {}", url);
319
320 let builder = self.apply_csrf(self.http.put(url).json(body));
321 let resp = builder.send().await.map_err(Error::Transport)?;
322
323 self.parse_envelope(resp).await
324 }
325
326 #[allow(dead_code)]
328 pub(crate) async fn delete<T: DeserializeOwned>(&self, url: Url) -> Result<Vec<T>, Error> {
329 debug!("DELETE {}", url);
330
331 let builder = self.apply_csrf(self.http.delete(url));
332 let resp = builder.send().await.map_err(Error::Transport)?;
333
334 self.parse_envelope(resp).await
335 }
336
337 pub async fn raw_get(&self, path: &str) -> Result<serde_json::Value, Error> {
341 let prefix = self.platform.session_prefix().unwrap_or("");
342 let base = self.base_url.as_str().trim_end_matches('/');
343 let prefix = prefix.trim_end_matches('/');
344 let url = Url::parse(&format!("{base}{prefix}/{path}")).expect("invalid raw URL");
345 self.get_raw(url).await
346 }
347
348 pub async fn raw_post(
350 &self,
351 path: &str,
352 body: &serde_json::Value,
353 ) -> Result<serde_json::Value, Error> {
354 let prefix = self.platform.session_prefix().unwrap_or("");
355 let base = self.base_url.as_str().trim_end_matches('/');
356 let prefix = prefix.trim_end_matches('/');
357 let url = Url::parse(&format!("{base}{prefix}/{path}")).expect("invalid raw URL");
358 debug!("POST (raw) {}", url);
359
360 let builder = self.apply_csrf(self.http.post(url).json(body));
361 let resp = builder.send().await.map_err(Error::Transport)?;
362 let status = resp.status();
363
364 if status == reqwest::StatusCode::UNAUTHORIZED {
365 return Err(self.unauthorized_error());
366 }
367 if !status.is_success() {
368 let body = resp.text().await.unwrap_or_default();
369 return Err(Error::SessionApi {
370 message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
371 });
372 }
373
374 let body = resp.text().await.map_err(Error::Transport)?;
375 serde_json::from_str(&body).map_err(|e| Error::Deserialization {
376 message: format!("{e}"),
377 body,
378 })
379 }
380
381 pub async fn raw_put(
383 &self,
384 path: &str,
385 body: &serde_json::Value,
386 ) -> Result<serde_json::Value, Error> {
387 let prefix = self.platform.session_prefix().unwrap_or("");
388 let base = self.base_url.as_str().trim_end_matches('/');
389 let prefix = prefix.trim_end_matches('/');
390 let url = Url::parse(&format!("{base}{prefix}/{path}")).expect("invalid raw URL");
391 debug!("PUT (raw) {}", url);
392
393 let builder = self.apply_csrf(self.http.put(url).json(body));
394 let resp = builder.send().await.map_err(Error::Transport)?;
395 let status = resp.status();
396
397 if status == reqwest::StatusCode::UNAUTHORIZED {
398 return Err(self.unauthorized_error());
399 }
400 if !status.is_success() {
401 let body = resp.text().await.unwrap_or_default();
402 return Err(Error::SessionApi {
403 message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
404 });
405 }
406
407 let body = resp.text().await.map_err(Error::Transport)?;
408 serde_json::from_str(&body).map_err(|e| Error::Deserialization {
409 message: format!("{e}"),
410 body,
411 })
412 }
413
414 pub async fn raw_patch(
416 &self,
417 path: &str,
418 body: &serde_json::Value,
419 ) -> Result<serde_json::Value, Error> {
420 let prefix = self.platform.session_prefix().unwrap_or("");
421 let base = self.base_url.as_str().trim_end_matches('/');
422 let prefix = prefix.trim_end_matches('/');
423 let url = Url::parse(&format!("{base}{prefix}/{path}")).expect("invalid raw URL");
424 debug!("PATCH (raw) {}", url);
425
426 let builder = self.apply_csrf(self.http.patch(url).json(body));
427 let resp = builder.send().await.map_err(Error::Transport)?;
428 let status = resp.status();
429
430 if status == reqwest::StatusCode::UNAUTHORIZED {
431 return Err(self.unauthorized_error());
432 }
433 if !status.is_success() {
434 let body = resp.text().await.unwrap_or_default();
435 return Err(Error::SessionApi {
436 message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
437 });
438 }
439
440 let body = resp.text().await.map_err(Error::Transport)?;
441 serde_json::from_str(&body).map_err(|e| Error::Deserialization {
442 message: format!("{e}"),
443 body,
444 })
445 }
446
447 pub async fn raw_delete(&self, path: &str) -> Result<(), Error> {
449 let prefix = self.platform.session_prefix().unwrap_or("");
450 let base = self.base_url.as_str().trim_end_matches('/');
451 let prefix = prefix.trim_end_matches('/');
452 let url = Url::parse(&format!("{base}{prefix}/{path}")).expect("invalid raw URL");
453 debug!("DELETE (raw) {}", url);
454
455 let builder = self.apply_csrf(self.http.delete(url));
456 let resp = builder.send().await.map_err(Error::Transport)?;
457 let status = resp.status();
458
459 if status == reqwest::StatusCode::UNAUTHORIZED {
460 return Err(self.unauthorized_error());
461 }
462 if !status.is_success() {
463 let body = resp.text().await.unwrap_or_default();
464 return Err(Error::SessionApi {
465 message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
466 });
467 }
468
469 Ok(())
470 }
471
472 async fn parse_envelope<T: DeserializeOwned>(
478 &self,
479 resp: reqwest::Response,
480 ) -> Result<Vec<T>, Error> {
481 let status = resp.status();
482
483 self.update_csrf_from_response(resp.headers());
485
486 if status == reqwest::StatusCode::UNAUTHORIZED {
487 return Err(self.unauthorized_error());
488 }
489
490 if status == reqwest::StatusCode::FORBIDDEN {
491 return Err(Error::SessionApi {
492 message: "insufficient permissions (HTTP 403)".into(),
493 });
494 }
495
496 if !status.is_success() {
497 let body = resp.text().await.unwrap_or_default();
498 return Err(Error::SessionApi {
499 message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
500 });
501 }
502
503 let body = resp.text().await.map_err(Error::Transport)?;
504
505 if let Ok(wrapper) = serde_json::from_str::<UnifiOsError>(&body)
507 && let Some(err) = wrapper.error
508 {
509 let msg = err.message.unwrap_or_default();
510 return Err(if err.code == 401 {
511 if msg.is_empty() {
512 self.unauthorized_error()
513 } else {
514 match self.unauthorized_error() {
515 Error::SessionExpired => Error::Authentication {
516 message: format!("session expired: {msg}"),
517 },
518 Error::InvalidApiKey => Error::Authentication {
519 message: format!("API key rejected: {msg}"),
520 },
521 other => other,
522 }
523 }
524 } else {
525 Error::SessionApi {
526 message: format!("UniFi OS error {}: {msg}", err.code),
527 }
528 });
529 }
530
531 let envelope: SessionResponse<T> = serde_json::from_str(&body).map_err(|e| {
532 let preview = &body[..body.len().min(200)];
533 Error::Deserialization {
534 message: format!("{e} (body preview: {preview:?})"),
535 body: body.clone(),
536 }
537 })?;
538
539 match envelope.meta.rc.as_str() {
540 "ok" => Ok(envelope.data),
541 _ => Err(Error::SessionApi {
542 message: envelope
543 .meta
544 .msg
545 .unwrap_or_else(|| format!("rc={}", envelope.meta.rc)),
546 }),
547 }
548 }
549}
550
#[cfg(test)]
mod tests {
    use url::Url;

    use super::{SessionAuth, SessionClient};
    use crate::{ControllerPlatform, Error};

    /// Builds a classic-controller client for the `default` site against a
    /// placeholder URL, using the requested authentication mode.
    fn client(auth: SessionAuth) -> SessionClient {
        let base_url = Url::parse("https://controller.example").expect("valid test URL");
        let site = String::from("default");
        SessionClient::with_client(
            reqwest::Client::new(),
            base_url,
            site,
            ControllerPlatform::ClassicController,
            auth,
        )
    }

    #[test]
    fn unauthorized_cookie_client_reports_session_expired() {
        let err = client(SessionAuth::Cookie).unauthorized_error();
        assert!(matches!(err, Error::SessionExpired));
    }

    #[test]
    fn unauthorized_api_key_client_reports_invalid_api_key() {
        let err = client(SessionAuth::ApiKey).unauthorized_error();
        assert!(matches!(err, Error::InvalidApiKey));
    }
}