annis_web/
state.rs

1use crate::auth::LoginInfo;
2use crate::{config::CliConfig, errors::AppError, Result, TEMPLATES_DIR};
3use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
4use chrono::Utc;
5use dashmap::DashMap;
6use minijinja::Value;
7use oauth2::{basic::BasicClient, PkceCodeVerifier};
8use serde::{Deserialize, Serialize};
9use std::{collections::BTreeSet, sync::Arc};
10use tempfile::NamedTempFile;
11use time::OffsetDateTime;
12use tokio::{sync::mpsc::Receiver, task::JoinHandle};
13use url::Url;
14
15#[derive(Serialize, Deserialize, Debug, Clone, Default)]
16pub struct Session {
17    selected_corpora: BTreeSet<String>,
18    #[serde(skip)]
19    session: tower_sessions::Session,
20    session_id: String,
21}
22
23impl Session {
24    pub const SELECTED_CORPORA_KEY: &str = "selected_corpora";
25
26    fn update_session(
27        session: &tower_sessions::Session,
28        selected_corpora: &BTreeSet<String>,
29    ) -> Result<()> {
30        session.insert(Self::SELECTED_CORPORA_KEY, selected_corpora.clone())?;
31        Ok(())
32    }
33
34    pub fn set_selected_corpora(&mut self, selected_corpora: BTreeSet<String>) -> Result<()> {
35        self.selected_corpora = selected_corpora;
36        Self::update_session(&self.session, &self.selected_corpora)?;
37        Ok(())
38    }
39
40    pub fn selected_corpora(&self) -> &BTreeSet<String> {
41        &self.selected_corpora
42    }
43
44    pub fn id(&self) -> &str {
45        &self.session_id
46    }
47
48    pub fn expiration_time(&self) -> Option<OffsetDateTime> {
49        self.session.expiration_time()
50    }
51}
52
53#[async_trait]
54impl<S> FromRequestParts<S> for Session
55where
56    S: Send + Sync,
57{
58    type Rejection = AppError;
59
60    async fn from_request_parts(
61        req: &mut Parts,
62        state: &S,
63    ) -> std::result::Result<Self, Self::Rejection> {
64        let session = tower_sessions::Session::from_request_parts(req, state).await?;
65        let selected_corpora: BTreeSet<String> = session
66            .get(Session::SELECTED_CORPORA_KEY)?
67            .unwrap_or_default();
68
69        Self::update_session(&session, &selected_corpora)?;
70
71        Ok(Self {
72            session_id: session.id().to_string(),
73            session,
74            selected_corpora,
75        })
76    }
77}
78
79#[derive(Debug)]
80pub struct ExportJob {
81    pub handle: JoinHandle<Result<NamedTempFile>>,
82    progress: f32,
83    progress_receiver: Receiver<f32>,
84}
85
86impl ExportJob {
87    pub fn new(
88        handle: JoinHandle<Result<NamedTempFile>>,
89        progress_receiver: Receiver<f32>,
90    ) -> ExportJob {
91        ExportJob {
92            handle,
93            progress_receiver,
94            progress: 0.0,
95        }
96    }
97
98    pub fn get_progress(&mut self) -> f32 {
99        while let Ok(new_progress) = self.progress_receiver.try_recv() {
100            self.progress = new_progress;
101        }
102        self.progress
103    }
104}
105
106#[derive(Clone)]
107pub enum SessionArg {
108    Session(Session),
109    Id(String),
110}
111
112impl SessionArg {
113    pub fn id(&self) -> String {
114        match self {
115            SessionArg::Session(s) => s.id().to_string(),
116            SessionArg::Id(id) => id.to_string(),
117        }
118    }
119}
120
121pub struct GlobalAppState {
122    pub service_url: Url,
123    pub templates: minijinja::Environment<'static>,
124    pub oauth2_client: Option<BasicClient>,
125    pub background_jobs: DashMap<String, ExportJob>,
126    pub auth_requests: DashMap<String, PkceCodeVerifier>,
127    pub login_info: Arc<DashMap<String, LoginInfo>>,
128    default_client: reqwest::Client,
129}
130
131impl GlobalAppState {
132    pub fn new(config: &CliConfig) -> Result<Self> {
133        let oauth2_client = config.create_oauth2_basic_client()?;
134
135        let mut templates = minijinja::Environment::new();
136
137        // Define any global variables
138        templates.add_global("url_prefix", config.frontend_prefix.to_string());
139        templates.add_global("login_configured", oauth2_client.is_some());
140
141        // Load templates by name from the included templates folder
142        templates.set_loader(|name| {
143            if let Some(file) = TEMPLATES_DIR.get_file(name) {
144                Ok(file.contents_utf8().map(|s| s.to_string()))
145            } else {
146                Ok(None)
147            }
148        });
149
150        let login_info: DashMap<String, LoginInfo> = DashMap::new();
151        let login_info = Arc::new(login_info);
152
153        // Add a function for the template that allows to easily extract the username
154        let login_info_for_template = login_info.clone();
155        templates.add_function("username", move |session: Value| -> Value {
156            if let Ok(session_id) = session.get_attr("session_id") {
157                if let Some(l) = login_info_for_template.get(&session_id.to_string()) {
158                    if let Ok(Some(username)) = l.user_id() {
159                        return Value::from(username);
160                    }
161                }
162            }
163            Value::UNDEFINED
164        });
165
166        let service_url = if config.service_url.is_empty() {
167            Url::parse("http://127.0.0.1:5711")?
168        } else {
169            Url::parse(&config.service_url)?
170        };
171        let default_client = reqwest::ClientBuilder::new().build()?;
172        let result = Self {
173            service_url,
174            background_jobs: DashMap::new(),
175            templates,
176            auth_requests: DashMap::new(),
177            login_info,
178            oauth2_client,
179            default_client,
180        };
181        Ok(result)
182    }
183
184    pub fn create_client(&self, session: &SessionArg) -> Result<reqwest::Client> {
185        if let SessionArg::Session(session) = session {
186            // Mark this login info as accessed, so we know it is not stale and should not be removed
187            self.login_info
188                .alter(&session.id().to_string(), |_, mut l| {
189                    if let (Some(old_expiry), Some(new_expiry)) =
190                        (l.expires_unix(), session.expiration_time())
191                    {
192                        // Check if the new expiration date is actually longer before replacing it
193                        if old_expiry < new_expiry.unix_timestamp() {
194                            l.set_expiration_unix(Some(new_expiry.unix_timestamp()));
195                        }
196                    } else {
197                        // Use the new expiration date
198                        l.set_expiration_unix(
199                            session.expiration_time().map(|t| t.unix_timestamp()),
200                        );
201                    }
202                    l
203                });
204        }
205
206        if let Some(login) = &self.login_info.get(&session.id()) {
207            // Return the authentifacted client
208            Ok(login.get_client())
209        } else {
210            // Fallback to the default client
211            Ok(self.default_client.clone())
212        }
213    }
214
215    /// Cleans up ressources coupled to sessions that are expired or non-existing.
216    pub async fn cleanup(&self) {
217        self.login_info.retain(|_session_id, login_info| {
218            if let Some(expiry) = login_info.expires_unix() {
219                Utc::now().timestamp() < expiry
220            } else {
221                true
222            }
223        });
224    }
225}
226
#[cfg(test)]
mod tests {

    use crate::config::CliConfig;

    use super::*;

    use oauth2::{basic::BasicTokenType, AccessToken, StandardTokenResponse};

    /// Stores login info with expiration `login_expiry` for a session that expires
    /// one day after `now`, and accesses the client for that session. Then asserts
    /// that the login info carries the session's expiration time.
    fn assert_login_expiry_follows_session(now: OffsetDateTime, login_expiry: Option<i64>) {
        let state = GlobalAppState::new(&CliConfig::default()).unwrap();

        // The user session itself expires one day from now
        let session_expiration = now.checked_add(time::Duration::days(1)).unwrap();
        let raw_session = tower_sessions::Session::new(Some(session_expiration));
        let session_id = raw_session.id().to_string();
        let session = Session {
            session_id: session_id.clone(),
            session: raw_session,
            ..Default::default()
        };

        let token_response = StandardTokenResponse::new(
            AccessToken::new("ABC".into()),
            BasicTokenType::Bearer,
            oauth2::EmptyExtraTokenFields {},
        );
        let login_info = LoginInfo::from_token(token_response, login_expiry).unwrap();
        state.login_info.insert(session_id.clone(), login_info);

        state
            .create_client(&SessionArg::Session(session))
            .unwrap();

        // The login info expiration time must be updated to match the session
        assert_eq!(
            Some(session_expiration.unix_timestamp()),
            state.login_info.get(&session_id).unwrap().expires_unix()
        );
    }

    #[test]
    fn client_access_time_updated_existing() {
        let now = OffsetDateTime::now_utc();
        // The login info expired a second ago, so a cleanup would remove it
        assert_login_expiry_follows_session(now, Some(now.unix_timestamp() - 1));
    }

    #[test]
    fn client_access_time_updated_set_from_session() {
        // The login info has no expiration date at all
        assert_login_expiry_follows_session(OffsetDateTime::now_utc(), None);
    }
}