1use crate::auth::LoginInfo;
2use crate::{config::CliConfig, errors::AppError, Result, TEMPLATES_DIR};
3use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
4use chrono::Utc;
5use dashmap::DashMap;
6use minijinja::Value;
7use oauth2::{basic::BasicClient, PkceCodeVerifier};
8use serde::{Deserialize, Serialize};
9use std::{collections::BTreeSet, sync::Arc};
10use tempfile::NamedTempFile;
11use time::OffsetDateTime;
12use tokio::{sync::mpsc::Receiver, task::JoinHandle};
13use url::Url;
14
15#[derive(Serialize, Deserialize, Debug, Clone, Default)]
16pub struct Session {
17 selected_corpora: BTreeSet<String>,
18 #[serde(skip)]
19 session: tower_sessions::Session,
20 session_id: String,
21}
22
23impl Session {
24 pub const SELECTED_CORPORA_KEY: &str = "selected_corpora";
25
26 fn update_session(
27 session: &tower_sessions::Session,
28 selected_corpora: &BTreeSet<String>,
29 ) -> Result<()> {
30 session.insert(Self::SELECTED_CORPORA_KEY, selected_corpora.clone())?;
31 Ok(())
32 }
33
34 pub fn set_selected_corpora(&mut self, selected_corpora: BTreeSet<String>) -> Result<()> {
35 self.selected_corpora = selected_corpora;
36 Self::update_session(&self.session, &self.selected_corpora)?;
37 Ok(())
38 }
39
40 pub fn selected_corpora(&self) -> &BTreeSet<String> {
41 &self.selected_corpora
42 }
43
44 pub fn id(&self) -> &str {
45 &self.session_id
46 }
47
48 pub fn expiration_time(&self) -> Option<OffsetDateTime> {
49 self.session.expiration_time()
50 }
51}
52
53#[async_trait]
54impl<S> FromRequestParts<S> for Session
55where
56 S: Send + Sync,
57{
58 type Rejection = AppError;
59
60 async fn from_request_parts(
61 req: &mut Parts,
62 state: &S,
63 ) -> std::result::Result<Self, Self::Rejection> {
64 let session = tower_sessions::Session::from_request_parts(req, state).await?;
65 let selected_corpora: BTreeSet<String> = session
66 .get(Session::SELECTED_CORPORA_KEY)?
67 .unwrap_or_default();
68
69 Self::update_session(&session, &selected_corpora)?;
70
71 Ok(Self {
72 session_id: session.id().to_string(),
73 session,
74 selected_corpora,
75 })
76 }
77}
78
79#[derive(Debug)]
80pub struct ExportJob {
81 pub handle: JoinHandle<Result<NamedTempFile>>,
82 progress: f32,
83 progress_receiver: Receiver<f32>,
84}
85
86impl ExportJob {
87 pub fn new(
88 handle: JoinHandle<Result<NamedTempFile>>,
89 progress_receiver: Receiver<f32>,
90 ) -> ExportJob {
91 ExportJob {
92 handle,
93 progress_receiver,
94 progress: 0.0,
95 }
96 }
97
98 pub fn get_progress(&mut self) -> f32 {
99 while let Ok(new_progress) = self.progress_receiver.try_recv() {
100 self.progress = new_progress;
101 }
102 self.progress
103 }
104}
105
106#[derive(Clone)]
107pub enum SessionArg {
108 Session(Session),
109 Id(String),
110}
111
112impl SessionArg {
113 pub fn id(&self) -> String {
114 match self {
115 SessionArg::Session(s) => s.id().to_string(),
116 SessionArg::Id(id) => id.to_string(),
117 }
118 }
119}
120
121pub struct GlobalAppState {
122 pub service_url: Url,
123 pub templates: minijinja::Environment<'static>,
124 pub oauth2_client: Option<BasicClient>,
125 pub background_jobs: DashMap<String, ExportJob>,
126 pub auth_requests: DashMap<String, PkceCodeVerifier>,
127 pub login_info: Arc<DashMap<String, LoginInfo>>,
128 default_client: reqwest::Client,
129}
130
131impl GlobalAppState {
132 pub fn new(config: &CliConfig) -> Result<Self> {
133 let oauth2_client = config.create_oauth2_basic_client()?;
134
135 let mut templates = minijinja::Environment::new();
136
137 templates.add_global("url_prefix", config.frontend_prefix.to_string());
139 templates.add_global("login_configured", oauth2_client.is_some());
140
141 templates.set_loader(|name| {
143 if let Some(file) = TEMPLATES_DIR.get_file(name) {
144 Ok(file.contents_utf8().map(|s| s.to_string()))
145 } else {
146 Ok(None)
147 }
148 });
149
150 let login_info: DashMap<String, LoginInfo> = DashMap::new();
151 let login_info = Arc::new(login_info);
152
153 let login_info_for_template = login_info.clone();
155 templates.add_function("username", move |session: Value| -> Value {
156 if let Ok(session_id) = session.get_attr("session_id") {
157 if let Some(l) = login_info_for_template.get(&session_id.to_string()) {
158 if let Ok(Some(username)) = l.user_id() {
159 return Value::from(username);
160 }
161 }
162 }
163 Value::UNDEFINED
164 });
165
166 let service_url = if config.service_url.is_empty() {
167 Url::parse("http://127.0.0.1:5711")?
168 } else {
169 Url::parse(&config.service_url)?
170 };
171 let default_client = reqwest::ClientBuilder::new().build()?;
172 let result = Self {
173 service_url,
174 background_jobs: DashMap::new(),
175 templates,
176 auth_requests: DashMap::new(),
177 login_info,
178 oauth2_client,
179 default_client,
180 };
181 Ok(result)
182 }
183
184 pub fn create_client(&self, session: &SessionArg) -> Result<reqwest::Client> {
185 if let SessionArg::Session(session) = session {
186 self.login_info
188 .alter(&session.id().to_string(), |_, mut l| {
189 if let (Some(old_expiry), Some(new_expiry)) =
190 (l.expires_unix(), session.expiration_time())
191 {
192 if old_expiry < new_expiry.unix_timestamp() {
194 l.set_expiration_unix(Some(new_expiry.unix_timestamp()));
195 }
196 } else {
197 l.set_expiration_unix(
199 session.expiration_time().map(|t| t.unix_timestamp()),
200 );
201 }
202 l
203 });
204 }
205
206 if let Some(login) = &self.login_info.get(&session.id()) {
207 Ok(login.get_client())
209 } else {
210 Ok(self.default_client.clone())
212 }
213 }
214
215 pub async fn cleanup(&self) {
217 self.login_info.retain(|_session_id, login_info| {
218 if let Some(expiry) = login_info.expires_unix() {
219 Utc::now().timestamp() < expiry
220 } else {
221 true
222 }
223 });
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    use crate::config::CliConfig;

    use oauth2::{basic::BasicTokenType, AccessToken, EmptyExtraTokenFields, StandardTokenResponse};

    /// Builds a session wrapper around a fresh tower session expiring one day
    /// after `now`, and returns it together with that expiration time.
    fn session_expiring_tomorrow(now: OffsetDateTime) -> (Session, OffsetDateTime) {
        let expiration = now
            .checked_add(time::Duration::days(1))
            .expect("one day after now must be representable");
        let raw_session = tower_sessions::Session::new(Some(expiration));
        let session = Session {
            session_id: raw_session.id().to_string(),
            session: raw_session,
            ..Default::default()
        };
        (session, expiration)
    }

    /// Creates login info for a dummy bearer token with the given expiry.
    fn bearer_login_info(expires_unix: Option<i64>) -> LoginInfo {
        let token_response = StandardTokenResponse::new(
            AccessToken::new("ABC".into()),
            BasicTokenType::Bearer,
            EmptyExtraTokenFields {},
        );
        LoginInfo::from_token(token_response, expires_unix).unwrap()
    }

    /// Calls `create_client` for `session` and returns the stored login expiry afterwards.
    fn expiry_after_create_client(state: &GlobalAppState, session: &Session) -> Option<i64> {
        state
            .create_client(&SessionArg::Session(session.clone()))
            .unwrap();
        state.login_info.get(session.id()).unwrap().expires_unix()
    }

    #[test]
    fn client_access_time_updated_existing() {
        let state = GlobalAppState::new(&CliConfig::default()).unwrap();
        let now = OffsetDateTime::now_utc();
        let (session, session_expiration) = session_expiring_tomorrow(now);

        // A login that expired a second ago is extended to the session's expiry.
        state.login_info.insert(
            session.id().to_string(),
            bearer_login_info(Some(now.unix_timestamp() - 1)),
        );

        assert_eq!(
            Some(session_expiration.unix_timestamp()),
            expiry_after_create_client(&state, &session)
        );
    }

    #[test]
    fn client_access_time_updated_set_from_session() {
        let state = GlobalAppState::new(&CliConfig::default()).unwrap();
        let now = OffsetDateTime::now_utc();
        let (session, session_expiration) = session_expiring_tomorrow(now);

        // A login without any expiry adopts the session's expiry.
        state
            .login_info
            .insert(session.id().to_string(), bearer_login_info(None));

        assert_eq!(
            Some(session_expiration.unix_timestamp()),
            expiry_after_create_client(&state, &session)
        );
    }
}