qcs_api_client_common/configuration/py.rs

1#![allow(unused_qualifications)]
2#![allow(non_local_definitions, reason = "necessary for pyo3::pymethods")]
3
4use pyo3::{
5    exceptions::{PyFileNotFoundError, PyOSError, PyRuntimeError, PyValueError},
6    prelude::*,
7    types::PyFunction,
8};
9use pyo3_asyncio::tokio::get_runtime;
10use rigetti_pyo3::{create_init_submodule, py_function_sync_async};
11use tokio_util::sync::CancellationToken;
12
13use crate::{
14    configuration::{
15        secrets::{DEFAULT_SECRETS_PATH, SECRETS_PATH_VAR},
16        settings::{DEFAULT_SETTINGS_PATH, SETTINGS_PATH_VAR},
17        API_URL_VAR, DEFAULT_API_URL, DEFAULT_GRPC_API_URL, DEFAULT_PROFILE_NAME,
18        DEFAULT_QUILC_URL, DEFAULT_QVM_URL, GRPC_API_URL_VAR, PROFILE_NAME_VAR, QUILC_URL_VAR,
19        QVM_URL_VAR,
20    },
21    impl_eq, impl_repr,
22};
23
24use super::{
25    error::TokenError,
26    secrets::{SecretAccessToken, SecretRefreshToken},
27    settings::AuthServer,
28    tokens::{ClientCredentials, ClientSecret, ExternallyManaged, PkceFlow},
29    ClientConfiguration, ClientConfigurationBuilder, LoadError, OAuthGrant, OAuthSession,
30    RefreshToken, TokenDispatcher,
31};
32
// Python submodule initializer generated by `rigetti_pyo3::create_init_submodule!`.
// Every type listed under `classes` and every value listed under `consts` is registered on the
// submodule when it is initialized.
// NOTE(review): there is no `funcs:` section, so the `#[pyfunction]`s produced by the
// `py_function_sync_async!` invocations in this file are not registered here. They appear to be
// used only as helpers by the class methods in this file; confirm that this is intentional.
33create_init_submodule! {
34    classes: [
35        ClientConfiguration,
        // Exposed to Python under its `#[pyo3(name)]`, i.e. `ClientConfigurationBuilder`.
36        PyClientConfigurationBuilder,
37        AuthServer,
38        OAuthSession,
39        RefreshToken,
40        ClientCredentials,
41        ClientSecret,
42        ExternallyManaged,
43        PkceFlow,
44        SecretAccessToken,
45        SecretRefreshToken
46    ],
    // Constants imported from `crate::configuration` (and its `secrets`/`settings` modules).
47    consts: [
48        DEFAULT_API_URL,
49        DEFAULT_GRPC_API_URL,
50        DEFAULT_QUILC_URL,
51        DEFAULT_QVM_URL,
52        DEFAULT_PROFILE_NAME,
53        PROFILE_NAME_VAR,
54        QUILC_URL_VAR,
55        QVM_URL_VAR,
56        API_URL_VAR,
57        GRPC_API_URL_VAR,
58        SETTINGS_PATH_VAR,
59        DEFAULT_SETTINGS_PATH,
60        SECRETS_PATH_VAR,
61        DEFAULT_SECRETS_PATH
62    ],
63}
64
65impl_eq!(RefreshToken);
66impl_repr!(RefreshToken);
67#[pymethods]
68impl RefreshToken {
69    #[new]
70    const fn __new__(refresh_token: SecretRefreshToken) -> Self {
71        Self::new(refresh_token)
72    }
73
74    #[getter]
75    #[pyo3(name = "refresh_token")]
76    fn py_refresh_token(&self) -> SecretRefreshToken {
77        self.refresh_token.clone()
78    }
79
80    #[setter]
81    #[pyo3(name = "refresh_token")]
82    fn py_set_refresh_token(&mut self, refresh_token: SecretRefreshToken) {
83        self.refresh_token = refresh_token;
84    }
85}
86
87impl_eq!(ClientCredentials);
88impl_repr!(ClientCredentials);
89#[pymethods]
90impl ClientCredentials {
91    #[new]
92    fn __new__(client_id: String, client_secret: String) -> Self {
93        Self::new(client_id, ClientSecret::from(client_secret))
94    }
95
96    #[getter]
97    #[pyo3(name = "client_id")]
98    fn py_client_id(&self) -> &str {
99        self.client_id()
100    }
101
102    #[getter]
103    #[pyo3(name = "client_secret")]
104    fn py_client_secret(&self) -> ClientSecret {
105        self.client_secret().clone()
106    }
107}
108
109impl_repr!(ExternallyManaged);
110#[pymethods]
111impl ExternallyManaged {
112    #[new]
113    fn __new__(refresh_function: Py<PyFunction>) -> Self {
114        #[allow(trivial_casts)] // Compilation fails without the cast.
115        // The provided refresh function will panic if there is an issue with the refresh function.
116        // This raises a `PanicException` within Python.
117        let refresh_closure = move |auth_server: AuthServer| {
118            let refresh_function = refresh_function.clone();
119            Box::pin(async move {
120                Python::with_gil(|py| {
121                    let result = refresh_function.call1(py, (auth_server.into_py(py),));
122                    match result {
123                        Ok(value) => value
124                            .extract::<String>(py)
125                            .map_or_else(|_| panic!("ExternallyManaged refresh function returned an unexpected type. Expected a string, got {value:?}"), Ok),
126                        Err(err) => Err(Box::<dyn std::error::Error + Send + Sync>::from(err))
127                    }
128                })
129            }) as super::tokens::RefreshResult
130        };
131
132        Self::new(refresh_closure)
133    }
134}
135
136impl_eq!(PkceFlow);
137// Does not implement `__repr__`, since the data contains a secret value.
138#[pymethods]
139impl PkceFlow {
140    #[new]
141    fn __new__(py: Python<'_>, auth_server: AuthServer) -> PyResult<Self> {
142        py.allow_threads(move || {
143            let runtime = get_runtime();
144            runtime.block_on(async move {
145                let cancel_token = cancel_token_with_ctrl_c();
146                Self::new_login_flow(cancel_token, &auth_server).await
147            })
148        })
149        .map_err(|err| PyRuntimeError::new_err(err.to_string()))
150    }
151
152    #[getter]
153    #[pyo3(name = "access_token")]
154    fn py_access_token(&self) -> SecretAccessToken {
155        self.access_token.clone()
156    }
157
158    #[getter]
159    #[pyo3(name = "refresh_token")]
160    fn py_refresh_token(&self) -> Option<SecretRefreshToken> {
161        self.refresh_token
162            .as_ref()
163            .map(|rt| rt.refresh_token.clone())
164    }
165}
166
167impl_repr!(OAuthSession);
168#[pymethods]
169impl OAuthSession {
170    #[new]
171    const fn __new__(
172        payload: OAuthGrant,
173        auth_server: AuthServer,
174        access_token: Option<SecretAccessToken>,
175    ) -> Self {
176        Self::new(payload, auth_server, access_token)
177    }
178
179    #[getter]
180    #[pyo3(name = "access_token")]
181    fn py_access_token(&self) -> Result<SecretAccessToken, TokenError> {
182        self.access_token().cloned()
183    }
184
185    #[getter]
186    #[pyo3(name = "payload")]
187    fn py_payload(&self, py: Python<'_>) -> PyObject {
188        match self.payload() {
189            OAuthGrant::ClientCredentials(ref client_credentials) => {
190                client_credentials.clone().into_py(py)
191            }
192            OAuthGrant::RefreshToken(ref refresh_token) => refresh_token.clone().into_py(py),
193            OAuthGrant::ExternallyManaged(ref externally_managed) => {
194                externally_managed.clone().into_py(py)
195            }
196            OAuthGrant::PkceFlow(ref pkce_tokens) => pkce_tokens.clone().into_py(py),
197        }
198    }
199
200    #[getter]
201    #[pyo3(name = "auth_server")]
202    fn py_auth_server(&self) -> AuthServer {
203        self.auth_server().clone()
204    }
205
206    #[pyo3(name = "validate")]
207    fn py_validate(&self) -> Result<SecretAccessToken, TokenError> {
208        self.validate()
209    }
210
211    #[pyo3(name = "request_access_token")]
212    fn py_request_access_token(&self, py: Python<'_>) -> PyResult<SecretAccessToken> {
213        py_request_access_token(py, self.clone())
214    }
215
216    #[pyo3(name = "request_access_token_async")]
217    fn py_request_access_token_async<'a>(&'a self, py: Python<'a>) -> PyResult<&'a PyAny> {
218        py_request_access_token_async(py, self.clone())
219    }
220}
221
222py_function_sync_async! {
223    #[pyfunction]
224    async fn get_oauth_session(tokens: Option<TokenDispatcher>) -> PyResult<OAuthSession> {
225        Ok(tokens.ok_or(TokenError::NoRefreshToken)?.tokens().await)
226    }
227}
228
229py_function_sync_async! {
230    #[pyfunction]
231    async fn get_bearer_access_token(configuration: ClientConfiguration) -> PyResult<SecretAccessToken> {
232        configuration.get_bearer_access_token().await.map_err(PyErr::from)
233    }
234}
235
236py_function_sync_async! {
237    #[pyfunction]
238    async fn request_access_token(session: OAuthSession) -> PyResult<SecretAccessToken> {
239        session.clone().request_access_token().await.cloned().map_err(PyErr::from)
240    }
241}
242
243impl_repr!(ClientConfiguration);
244#[pymethods]
245impl ClientConfiguration {
246    #[staticmethod]
247    #[pyo3(name = "load_default")]
248    fn py_load_default(_py: Python<'_>) -> Result<Self, LoadError> {
249        Self::load_default()
250    }
251
252    #[staticmethod]
253    #[pyo3(name = "load_default_with_login")]
254    fn py_load_default_with_login(py: Python<'_>) -> PyResult<Self> {
255        py.allow_threads(move || {
256            let runtime = get_runtime();
257            runtime.block_on(async move {
258                let cancel_token = cancel_token_with_ctrl_c();
259                Self::load_with_login(cancel_token, None).await
260            })
261        })
262        .map_err(|err| PyRuntimeError::new_err(err.to_string()))
263    }
264
265    #[staticmethod]
266    #[pyo3(name = "builder")]
267    fn py_builder() -> PyClientConfigurationBuilder {
268        PyClientConfigurationBuilder::default()
269    }
270
271    #[staticmethod]
272    #[pyo3(name = "load_profile")]
273    fn py_load_profile(_py: Python<'_>, profile_name: String) -> Result<Self, LoadError> {
274        Self::load_profile(profile_name)
275    }
276
277    #[getter]
278    fn get_api_url(&self) -> &str {
279        &self.api_url
280    }
281
282    #[getter]
283    fn get_grpc_api_url(&self) -> &str {
284        &self.grpc_api_url
285    }
286
287    #[getter]
288    fn get_quilc_url(&self) -> &str {
289        &self.quilc_url
290    }
291
292    #[getter]
293    fn get_qvm_url(&self) -> &str {
294        &self.qvm_url
295    }
296
297    #[pyo3(name = "get_bearer_access_token")]
298    fn py_get_bearer_access_token(&self, py: Python<'_>) -> PyResult<SecretAccessToken> {
299        py_get_bearer_access_token(py, self.clone())
300    }
301
302    #[pyo3(name = "get_bearer_access_token_async")]
303    fn py_get_bearer_access_token_async<'a>(&self, py: Python<'a>) -> PyResult<&'a PyAny> {
304        py_get_bearer_access_token_async(py, self.clone())
305    }
306
307    /// Get the configured tokens.
308    ///
309    /// # Errors
310    ///
311    /// - Raises a `TokenError` if there is a problem fetching the tokens
312    pub fn get_oauth_session(&self, py: Python<'_>) -> PyResult<OAuthSession> {
313        py_get_oauth_session(py, self.oauth_session.clone())
314    }
315
316    #[allow(clippy::needless_pass_by_value)] // self_ must be passed by value
317    fn get_oauth_session_async<'a>(
318        self_: PyRefMut<'a, Self>,
319        py: Python<'a>,
320    ) -> PyResult<&'a PyAny> {
321        py_get_oauth_session_async(py, self_.oauth_session.clone())
322    }
323}
324
325#[pyclass]
326#[pyo3(name = "ClientConfigurationBuilder")]
327#[derive(Clone, Default)]
328struct PyClientConfigurationBuilder(ClientConfigurationBuilder);
329
330#[pymethods]
331impl PyClientConfigurationBuilder {
332    #[new]
333    fn new() -> Self {
334        Self::default()
335    }
336
337    fn build(&self) -> Result<ClientConfiguration, LoadError> {
338        Ok(self.0.build()?)
339    }
340
341    #[setter]
342    fn api_url(&mut self, api_url: String) {
343        self.0.api_url(api_url);
344    }
345
346    #[setter]
347    fn grpc_api_url(&mut self, grpc_api_url: String) {
348        self.0.grpc_api_url(grpc_api_url);
349    }
350
351    #[setter]
352    fn quilc_url(&mut self, quilc_url: String) {
353        self.0.quilc_url(quilc_url);
354    }
355
356    #[setter]
357    fn qvm_url(&mut self, qvm_url: String) {
358        self.0.qvm_url(qvm_url);
359    }
360
361    #[setter]
362    fn oauth_session(&mut self, oauth_session: Option<OAuthSession>) {
363        self.0.oauth_session(oauth_session);
364    }
365}
366
367impl_repr!(AuthServer);
368impl_eq!(AuthServer);
369#[pymethods]
370impl AuthServer {
371    #[new]
372    const fn __new__(client_id: String, issuer: String, scopes: Option<Vec<String>>) -> Self {
373        Self::new(client_id, issuer, scopes)
374    }
375
376    #[staticmethod]
377    #[pyo3(name = "default")]
378    fn py_default() -> Self {
379        Self::default()
380    }
381
382    /// Get the configured OAuth OIDC client id.
383    #[getter]
384    #[must_use]
385    pub fn get_client_id(&self) -> &str {
386        &self.client_id
387    }
388
389    /// Set an OAuth OIDC client id.
390    #[setter(client_id)]
391    pub fn py_set_client_id(&mut self, client_id: String) {
392        self.client_id = client_id;
393    }
394
395    /// Get the OAuth OIDC issuer URL.
396    #[getter]
397    #[must_use]
398    pub fn get_issuer(&self) -> &str {
399        &self.issuer
400    }
401
402    /// Set an OAuth OIDC issuer URL.
403    #[setter(issuer)]
404    pub fn py_set_issuer(&mut self, issuer: String) {
405        self.issuer = issuer;
406    }
407}
408
409impl From<LoadError> for PyErr {
410    fn from(value: LoadError) -> Self {
411        let message = value.to_string();
412        match value {
413            LoadError::Load(_)
414            | LoadError::Build(_)
415            | LoadError::ProfileNotFound(_)
416            | LoadError::AuthServerNotFound(_)
417            | LoadError::PkceFlow(_) => PyValueError::new_err(message),
418            LoadError::EnvVar { .. } | LoadError::Io(_) => PyOSError::new_err(message),
419            LoadError::Path { .. } => PyFileNotFoundError::new_err(message),
420            #[cfg(feature = "tracing-config")]
421            LoadError::TracingFilterParseError(_) => PyValueError::new_err(message),
422        }
423    }
424}
425
426impl From<TokenError> for PyErr {
427    fn from(value: TokenError) -> Self {
428        let message = value.to_string();
429        match value {
430            TokenError::NoRefreshToken
431            | TokenError::NoCredentials
432            | TokenError::NoAccessToken
433            | TokenError::NoAuthServer
434            | TokenError::InvalidAccessToken(_)
435            | TokenError::Fetch(_)
436            | TokenError::ExternallyManaged(_)
437            | TokenError::Write(_)
438            | TokenError::Discovery(_) => PyValueError::new_err(message),
439        }
440    }
441}
442
443fn cancel_token_with_ctrl_c() -> CancellationToken {
444    let cancel_token = CancellationToken::new();
445    let cancel_token_ctrl_c = cancel_token.clone();
446    tokio::spawn(cancel_token.clone().run_until_cancelled_owned(async move {
447        match tokio::signal::ctrl_c().await {
448            Ok(()) => cancel_token_ctrl_c.cancel(),
449            Err(error) => eprintln!("Failed to register signal handler: {error}"),
450        }
451    }));
452    cancel_token
453}