qcs_api_client_common/configuration/py.rs

1#![allow(unused_qualifications)]
2
3use pyo3::{
4    exceptions::{PyFileNotFoundError, PyOSError, PyValueError},
5    prelude::*,
6    types::PyFunction,
7};
8use rigetti_pyo3::{create_init_submodule, py_function_sync_async};
9
10use crate::{
11    configuration::{
12        API_URL_VAR, DEFAULT_API_URL, DEFAULT_GRPC_API_URL, DEFAULT_PROFILE_NAME,
13        DEFAULT_QUILC_URL, DEFAULT_QVM_URL, DEFAULT_SECRETS_PATH, DEFAULT_SETTINGS_PATH,
14        GRPC_API_URL_VAR, PROFILE_NAME_VAR, QUILC_URL_VAR, QVM_URL_VAR, SECRETS_PATH_VAR,
15        SETTINGS_PATH_VAR,
16    },
17    impl_eq, impl_repr,
18};
19
20use super::{
21    error::TokenError, settings::AuthServer, tokens::ClientCredentials, ClientConfiguration,
22    ClientConfigurationBuilder, ExternallyManaged, LoadError, OAuthGrant, OAuthSession,
23    RefreshToken, TokenDispatcher,
24};
25
26create_init_submodule! {
27    classes: [
28        ClientConfiguration,
29        PyClientConfigurationBuilder,
30        AuthServer,
31        OAuthSession,
32        RefreshToken,
33        ClientCredentials,
34        ExternallyManaged
35    ],
36    consts: [
37        DEFAULT_API_URL,
38        DEFAULT_GRPC_API_URL,
39        DEFAULT_QUILC_URL,
40        DEFAULT_QVM_URL,
41        DEFAULT_PROFILE_NAME,
42        PROFILE_NAME_VAR,
43        QUILC_URL_VAR,
44        QVM_URL_VAR,
45        API_URL_VAR,
46        GRPC_API_URL_VAR,
47        SETTINGS_PATH_VAR,
48        DEFAULT_SETTINGS_PATH,
49        SECRETS_PATH_VAR,
50        DEFAULT_SECRETS_PATH
51    ],
52}
53
54impl_eq!(RefreshToken);
55impl_repr!(RefreshToken);
56#[pymethods]
57impl RefreshToken {
58    #[new]
59    const fn __new__(refresh_token: String) -> Self {
60        Self::new(refresh_token)
61    }
62
63    #[getter]
64    #[pyo3(name = "refresh_token")]
65    fn py_refresh_token(&self) -> &str {
66        &self.refresh_token
67    }
68
69    #[setter]
70    #[pyo3(name = "refresh_token")]
71    fn py_set_refresh_token(&mut self, refresh_token: String) {
72        self.refresh_token = refresh_token;
73    }
74}
75
76impl_eq!(ClientCredentials);
77// Does not implement `__repr__`, since the data contains a secret value.
78#[pymethods]
79impl ClientCredentials {
80    #[new]
81    const fn __new__(client_id: String, client_secret: String) -> Self {
82        Self::new(client_id, client_secret)
83    }
84
85    #[getter]
86    #[pyo3(name = "client_id")]
87    fn py_client_id(&self) -> &str {
88        self.client_id()
89    }
90
91    #[getter]
92    #[pyo3(name = "client_secret")]
93    fn py_client_secret(&self) -> &str {
94        self.client_secret()
95    }
96}
97
98impl_repr!(ExternallyManaged);
99#[pymethods]
100impl ExternallyManaged {
101    #[new]
102    fn __new__(refresh_function: Py<PyFunction>) -> Self {
103        #[allow(trivial_casts)] // Compilation fails without the cast.
104        // The provided refresh function will panic if there is an issue with the refresh function.
105        // This raises a `PanicException` within Python.
106        let refresh_closure = move |auth_server: AuthServer| {
107            let refresh_function = refresh_function.clone();
108            Box::pin(async move {
109                Python::with_gil(|py| {
110                    let result = refresh_function.call1(py, (auth_server.into_py(py),));
111                    match result {
112                        Ok(value) => value
113                            .extract::<String>(py)
114                            .map_or_else(|_| panic!("ExternallyManaged refresh function returned an unexpected type. Expected a string, got {value:?}"), Ok),
115                        Err(err) => Err(Box::<dyn std::error::Error + Send + Sync>::from(err))
116                    }
117                })
118            }) as super::tokens::RefreshResult
119        };
120
121        Self::new(refresh_closure)
122    }
123}
124
125impl_repr!(OAuthSession);
126#[pymethods]
127impl OAuthSession {
128    #[new]
129    const fn __new__(
130        payload: OAuthGrant,
131        auth_server: AuthServer,
132        access_token: Option<String>,
133    ) -> Self {
134        Self::new(payload, auth_server, access_token)
135    }
136
137    #[getter]
138    #[pyo3(name = "access_token")]
139    fn py_access_token(&self) -> Result<&str, TokenError> {
140        self.access_token()
141    }
142
143    #[getter]
144    #[pyo3(name = "payload")]
145    fn py_payload(&self, py: Python<'_>) -> PyObject {
146        match self.payload() {
147            OAuthGrant::ClientCredentials(ref client_credentials) => {
148                client_credentials.clone().into_py(py)
149            }
150            OAuthGrant::RefreshToken(ref refresh_token) => refresh_token.clone().into_py(py),
151            OAuthGrant::ExternallyManaged(ref externally_managed) => {
152                externally_managed.clone().into_py(py)
153            }
154        }
155    }
156
157    #[getter]
158    #[pyo3(name = "auth_server")]
159    fn py_auth_server(&self) -> AuthServer {
160        self.auth_server().clone()
161    }
162
163    #[pyo3(name = "validate")]
164    fn py_validate(&self) -> Result<String, TokenError> {
165        self.validate()
166    }
167
168    #[pyo3(name = "request_access_token")]
169    fn py_request_access_token(&self, py: Python<'_>) -> PyResult<String> {
170        py_request_access_token(py, self.clone())
171    }
172
173    #[pyo3(name = "request_access_token_async")]
174    fn py_request_access_token_async<'a>(&'a self, py: Python<'a>) -> PyResult<&'a PyAny> {
175        py_request_access_token_async(py, self.clone())
176    }
177}
178
179py_function_sync_async! {
180    #[pyfunction]
181    async fn get_oauth_session(tokens: Option<TokenDispatcher>) -> PyResult<OAuthSession> {
182        Ok(tokens.ok_or(TokenError::NoRefreshToken)?.tokens().await)
183    }
184}
185
186py_function_sync_async! {
187    #[pyfunction]
188    async fn get_bearer_access_token(configuration: ClientConfiguration) -> PyResult<String> {
189        configuration.get_bearer_access_token().await.map_err(PyErr::from)
190    }
191}
192
193py_function_sync_async! {
194    #[pyfunction]
195    async fn request_access_token(session: OAuthSession) -> PyResult<String> {
196        session.clone().request_access_token().await.map(std::string::ToString::to_string).map_err(PyErr::from)
197    }
198}
199
200impl_repr!(ClientConfiguration);
201#[pymethods]
202impl ClientConfiguration {
203    #[staticmethod]
204    #[pyo3(name = "load_default")]
205    fn py_load_default(_py: Python<'_>) -> Result<Self, LoadError> {
206        Self::load_default()
207    }
208
209    #[staticmethod]
210    #[pyo3(name = "builder")]
211    fn py_builder() -> PyClientConfigurationBuilder {
212        PyClientConfigurationBuilder::default()
213    }
214
215    #[staticmethod]
216    #[pyo3(name = "load_profile")]
217    fn py_load_profile(_py: Python<'_>, profile_name: String) -> Result<Self, LoadError> {
218        Self::load_profile(profile_name)
219    }
220
221    #[getter]
222    fn get_api_url(&self) -> &str {
223        &self.api_url
224    }
225
226    #[getter]
227    fn get_grpc_api_url(&self) -> &str {
228        &self.grpc_api_url
229    }
230
231    #[getter]
232    fn get_quilc_url(&self) -> &str {
233        &self.quilc_url
234    }
235
236    #[getter]
237    fn get_qvm_url(&self) -> &str {
238        &self.qvm_url
239    }
240
241    #[pyo3(name = "get_bearer_access_token")]
242    fn py_get_bearer_access_token(&self, py: Python<'_>) -> PyResult<String> {
243        py_get_bearer_access_token(py, self.clone())
244    }
245
246    #[pyo3(name = "get_bearer_access_token_async")]
247    fn py_get_bearer_access_token_async<'a>(&self, py: Python<'a>) -> PyResult<&'a PyAny> {
248        py_get_bearer_access_token_async(py, self.clone())
249    }
250
251    /// Get the configured tokens.
252    ///
253    /// # Errors
254    ///
255    /// - Raises a TokenError if there is a problem fetching the tokens
256    pub fn get_oauth_session(&self, py: Python<'_>) -> PyResult<OAuthSession> {
257        py_get_oauth_session(py, self.oauth_session.clone())
258    }
259
260    #[allow(clippy::needless_pass_by_value)] // self_ must be passed by value
261    fn get_oauth_session_async<'a>(
262        self_: PyRefMut<'a, Self>,
263        py: Python<'a>,
264    ) -> PyResult<&'a PyAny> {
265        py_get_oauth_session_async(py, self_.oauth_session.clone())
266    }
267}
268
269#[pyclass]
270#[pyo3(name = "ClientConfigurationBuilder")]
271#[derive(Clone, Default)]
272struct PyClientConfigurationBuilder(ClientConfigurationBuilder);
273
274#[pymethods]
275impl PyClientConfigurationBuilder {
276    #[new]
277    fn new() -> Self {
278        Self::default()
279    }
280
281    fn build(&self) -> Result<ClientConfiguration, LoadError> {
282        Ok(self.0.build()?)
283    }
284
285    #[setter]
286    fn api_url(&mut self, api_url: String) {
287        self.0.api_url(api_url);
288    }
289
290    #[setter]
291    fn grpc_api_url(&mut self, grpc_api_url: String) {
292        self.0.grpc_api_url(grpc_api_url);
293    }
294
295    #[setter]
296    fn quilc_url(&mut self, quilc_url: String) {
297        self.0.quilc_url(quilc_url);
298    }
299
300    #[setter]
301    fn qvm_url(&mut self, qvm_url: String) {
302        self.0.qvm_url(qvm_url);
303    }
304
305    #[setter]
306    fn oauth_session(&mut self, oauth_session: Option<OAuthSession>) {
307        self.0.oauth_session(oauth_session);
308    }
309}
310
311impl_repr!(AuthServer);
312impl_eq!(AuthServer);
313#[pymethods]
314impl AuthServer {
315    #[new]
316    const fn py_new(client_id: String, issuer: String) -> Self {
317        Self::new(client_id, issuer)
318    }
319
320    #[staticmethod]
321    #[pyo3(name = "default")]
322    fn py_default() -> Self {
323        Self::default()
324    }
325
326    /// Get the configured Okta client id.
327    #[getter]
328    #[must_use]
329    pub fn get_client_id(&self) -> &str {
330        self.client_id()
331    }
332
333    /// Set an Okta client id.
334    #[setter(client_id)]
335    pub fn py_set_client_id(&mut self, id: String) {
336        self.set_client_id(id);
337    }
338
339    /// Get the Okta issuer URL.
340    #[getter]
341    #[must_use]
342    pub fn get_issuer(&self) -> &str {
343        self.issuer()
344    }
345
346    /// Set an Okta issuer URL.
347    #[setter(issuer)]
348    pub fn py_set_issuer(&mut self, issuer: String) {
349        self.set_issuer(issuer);
350    }
351}
352
353impl From<LoadError> for PyErr {
354    fn from(value: LoadError) -> Self {
355        let message = value.to_string();
356        match value {
357            LoadError::Load(_)
358            | LoadError::Build(_)
359            | LoadError::ProfileNotFound(_)
360            | LoadError::AuthServerNotFound(_) => PyValueError::new_err(message),
361            LoadError::EnvVar { .. } | LoadError::Io(_) => PyOSError::new_err(message),
362            LoadError::Path { .. } => PyFileNotFoundError::new_err(message),
363            #[cfg(feature = "tracing-config")]
364            LoadError::TracingFilterParseError(_) => PyValueError::new_err(message),
365        }
366    }
367}
368
369impl From<TokenError> for PyErr {
370    fn from(value: TokenError) -> Self {
371        let message = value.to_string();
372        match value {
373            TokenError::NoRefreshToken
374            | TokenError::NoCredentials
375            | TokenError::NoAccessToken
376            | TokenError::NoAuthServer
377            | TokenError::InvalidAccessToken(_)
378            | TokenError::Fetch(_)
379            | TokenError::ExternallyManaged(_)
380            | TokenError::Write(_) => PyValueError::new_err(message),
381        }
382    }
383}