bevy_gotrue/
lib.rs

//! Heavily inspired by (and partially copied from) https://crates.io/crates/go_true
2mod builder;
3mod client;
4
5pub use builder::Builder;
6pub use client::Client;
7
8use std::io::{ErrorKind, Read, Write};
9use std::net::TcpListener;
10use std::ops::Deref;
11
12use bevy::prelude::*;
13use bevy::tasks::futures_lite::future;
14use bevy::tasks::{block_on, AsyncComputeTaskPool, Task};
15use bevy::utils::HashMap;
16use bevy_http_client::prelude::{
17    HttpTypedRequestTrait, TypedRequest, TypedResponse, TypedResponseError,
18};
19use bevy_http_client::{HttpClient, HttpClientSetting};
20use ehttp::Headers;
21use serde::Deserialize;
22use serde_json::Value;
23
24#[derive(Debug, Resource, Deserialize, Clone)]
25pub struct Session {
26    pub access_token: String,
27    pub token_type: String,
28    pub expires_in: i32,
29    pub refresh_token: String,
30    pub user: User,
31}
32
33#[derive(Debug, Resource, Deserialize, Default, Clone)]
34pub struct User {
35    pub id: String,
36    pub email: String,
37    pub aud: String,
38    pub role: String,
39    pub email_confirmed_at: Option<String>,
40    pub phone: String,
41    pub last_sign_in_at: Option<String>,
42    pub created_at: String,
43    pub updated_at: String,
44}
45
46#[derive(Debug, Resource)]
47pub struct UserAttributes {
48    pub email: String,
49    pub password: String,
50    pub data: Value,
51}
52
53pub struct UserList {
54    pub users: Vec<User>,
55}
56
/// A user record after an update.
///
/// NOTE(review): the `new_email` / `email_change_sent_at` fields suggest this describes a
/// pending email change — confirm against the server API. Timestamps are raw strings.
#[derive(Debug, Clone, Default)]
pub struct UserUpdate {
    pub id: String,
    pub email: String,
    pub new_email: String,
    pub email_change_sent_at: String,
    pub created_at: String,
    pub updated_at: String,
}
65
66#[derive(Resource, Clone)]
67pub struct AuthCreds {
68    pub id: String,
69    pub password: String,
70}
71
/// Bevy plugin that wires up GoTrue authentication.
///
/// `bevy_http_client`'s plugin must be added before this one.
#[derive(Debug, Clone)]
pub struct AuthPlugin {
    /// GoTrue endpoint; copied onto the `Client` resource when the plugin is built.
    pub endpoint: String,
}

impl AuthPlugin {
    /// Creates the plugin for the GoTrue server at `endpoint`.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...).
    pub fn new(endpoint: impl Into<String>) -> Self {
        Self {
            endpoint: endpoint.into(),
        }
    }
}
81
82impl Plugin for AuthPlugin {
83    fn build(&self, app: &mut App) {
84        if !app.world().contains_resource::<HttpClientSetting>() {
85            panic!("Please load bevy_http_client::BevyHttpClient plugin!");
86        }
87        let headers = Headers::new(&[]);
88        let sign_in = app.world_mut().register_system(sign_in);
89
90        app.world_mut().insert_resource(Client {
91            endpoint: self.endpoint.clone(),
92            headers,
93            sign_in,
94            access_token: None,
95        });
96
97        app.add_systems(PreStartup, start_provider_server)
98            .add_systems(
99                Update,
100                (
101                    sign_in_recv,
102                    sign_in_err, // TODO runconditions
103                    poll_listener.run_if(resource_exists::<ProviderListener>),
104                ),
105            )
106            .register_request_type::<Session>();
107    }
108}
109
110// logged_in runcondition
111pub fn just_logged_in(session: Option<Res<Session>>) -> bool {
112    if let Some(session) = session {
113        session.is_added()
114    } else {
115        false
116    }
117}
118
119pub fn is_logged_in(session: Option<Res<Session>>) -> bool {
120    session.is_some()
121}
122
123#[derive(Resource)]
124struct ProviderListener(Task<Result<Session, std::io::Error>>);
125
126pub fn start_provider_server(mut commands: Commands) {
127    let pool = AsyncComputeTaskPool::get();
128    let task = pool.spawn(async {
129        let listener = TcpListener::bind("127.0.0.1:6969").expect("Couldn't bind port 6969.");
130
131        let mut params = HashMap::new();
132
133        loop {
134            let (mut stream, _) = listener.accept().expect("Listener IO error");
135
136            // This javascript is mental, I have to make fetch happen because GoTrue puts the
137            // access token in the URI hash? Like is that intentional, surely should be on search
138            // params. This fix does require JS in browser but most oAuth sign in pages probably do too, so
139            // should be a non-issue.
140            let message = String::from(
141                "<script>fetch(`http://localhost:6969/token?${window.location.hash.replace('#','')})`)</script><h1>GoTrue-Rs</h1><h2>Signin sent to program.</h2><h3>You may close this tab.</h3>",
142            );
143
144            // TODO optional redirect to user provided URI
145
146            let res = format!(
147                "HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{}",
148                message.len(),
149                message
150            );
151
152            loop {
153                match stream.write(res.as_bytes()) {
154                    Ok(_n) => break,
155                    Err(ref e) if e.kind() == ErrorKind::WouldBlock => continue,
156                    Err(e) => println!("Couldn't respond. {}", e),
157                }
158            }
159
160            let mut buf = [0; 4096];
161
162            loop {
163                match stream.read(&mut buf) {
164                    Ok(0) => break,
165                    Ok(_n) => break,
166                    Err(ref e) if e.kind() == ErrorKind::WouldBlock => continue,
167                    Err(e) => {
168                        return Err(e);
169                    }
170                }
171            }
172
173            let raw = String::from_utf8(buf.to_vec()).unwrap();
174
175            let request_line = raw.lines().collect::<Vec<_>>()[0];
176
177            if !request_line.starts_with("GET /token?") {
178                // If this request isn't the one we sent with JS fetch, ignore it and wait for the
179                // right one.
180                continue;
181            }
182
183            let split_req = request_line
184                .strip_prefix("GET /token?")
185                .unwrap()
186                .split('&')
187                .collect::<Vec<&str>>();
188
189            for param in split_req {
190                let split_param = param.split('=').collect::<Vec<&str>>();
191                params.insert(split_param[0].to_owned(), split_param[1].to_owned());
192            }
193
194            if params.get("access_token").is_some() {
195                break;
196            }
197        }
198
199        let access_token = params.get("access_token").unwrap().clone();
200        let refresh_token = params.get("refresh_token").unwrap().clone();
201        let token_type = "JWT".to_string();
202        let expires_in:i32 = params.get("expires_in").unwrap_or(&"3600".to_string()).clone().parse().unwrap();
203
204        let session = Session {
205            access_token,
206            refresh_token,
207            token_type,
208            expires_in,
209            user: User::default(),
210        };
211
212        Ok(session)
213    });
214
215    commands.insert_resource(ProviderListener(task));
216}
217
218fn poll_listener(mut commands: Commands, mut task: ResMut<ProviderListener>) {
219    if let Some(Ok(result)) = block_on(future::poll_once(&mut task.0)) {
220        commands.insert_resource(result);
221        commands.remove_resource::<ProviderListener>();
222    }
223}
224
225// Oneshot
226pub fn sign_in(
227    In(creds): In<AuthCreds>,
228    auth: Res<Client>,
229    mut evw: EventWriter<TypedRequest<Session>>,
230) {
231    let req = auth
232        .builder()
233        .sign_in(builder::EmailOrPhone::Email(creds.id), creds.password);
234
235    let req = HttpClient::new().request(req).with_type::<Session>();
236    evw.send(req);
237}
238
239fn sign_in_recv(
240    mut evr: EventReader<TypedResponse<Session>>,
241    mut client: ResMut<Client>,
242    mut commands: Commands,
243) {
244    for res in evr.read() {
245        let session = res.deref();
246        client.access_token = Some(session.access_token.clone());
247        commands.insert_resource(session.clone());
248    }
249}
250
251fn sign_in_err(mut evr: EventReader<TypedResponseError<Session>>) {
252    for err in evr.read() {
253        println!("Login error: {:?}", err);
254    }
255}