ocpi/
client.rs

1use crate::{
2    context::ExtendedContext,
3    error::ServerError,
4    types::{self, CommandResult},
5    Context, Result, Session,
6};
7use http::HeaderMap;
8use reqwest::Method;
9
10#[derive(serde::Deserialize)]
11struct Reply<T> {
12    pub status_code: u32,
13
14    pub data: Option<T>,
15
16    #[serde(rename = "status_message")]
17    pub message: Option<String>,
18
19    #[allow(dead_code)]
20    pub timestamp: types::DateTime,
21}
22
23/// Implements the different OCPI calls as a client.
24pub struct Client {
25    http: reqwest::Client,
26}
27
28impl Clone for Client {
29    fn clone(&self) -> Self {
30        Self {
31            http: self.http.clone(),
32        }
33    }
34}
35
36impl Default for Client {
37    fn default() -> Self {
38        const CARGO_PKG_VERSION: &str = env!("CARGO_PKG_VERSION");
39
40        let mut def_headers = HeaderMap::new();
41        def_headers.append(
42            "user-agent",
43            format!("ocpi-rs {}", CARGO_PKG_VERSION)
44                .parse()
45                .expect("Invalid CARGO_PKG_VERSION"),
46        );
47
48        Self {
49            http: reqwest::Client::builder()
50                .default_headers(def_headers)
51                .build()
52                .expect("Building default OCPI client"),
53        }
54    }
55}
56
57impl Client {
58    pub fn new(http: reqwest::Client) -> Self {
59        Self { http }
60    }
61
62    fn req(
63        &self,
64        ctx: ExtendedContext<'_>,
65        method: reqwest::Method,
66        url: impl Into<String>,
67    ) -> ReqBuilder {
68        let url = url.into();
69        let b = self.http.request(method, &url).set_ocpi_ctx(ctx);
70        ReqBuilder { url, b }
71    }
72
73    /// Given a version number and the URL of the
74    /// root OCPI versions endpoint and the client.
75    /// The function first calls the versions URL to confirm
76    /// that the required version exists.
77    /// And then retrieves the available endpoints for that version.
78    pub async fn get_endpoints_for_version(
79        &self,
80        ctx: ExtendedContext<'_>,
81        versions_url: types::Url,
82        desired_version: types::VersionNumber,
83    ) -> Result<types::VersionDetails> {
84        let versions = self
85            .req(ctx, Method::GET, versions_url.clone())
86            .send::<Vec<types::Version>>()
87            .await?;
88
89        // Try to find the matching version.
90        let version = versions
91            .into_iter()
92            .find(|v| v.version == desired_version)
93            .ok_or(ServerError::IncompatibleEndpoints)?;
94
95        let version_details = self
96            .req(ctx, Method::GET, version.url.clone())
97            .send::<types::VersionDetails>()
98            .await?;
99
100        Ok(version_details)
101    }
102
103    pub async fn post_response(
104        &self,
105        ctx: ExtendedContext<'_>,
106        url: &url::Url,
107        command_result: CommandResult,
108    ) -> Result<()> {
109        self.req(ctx, Method::GET, url.to_string())
110            .body(&command_result)
111            .send()
112            .await
113    }
114
115    pub async fn put_session(&self, ctx: &Context, url: &url::Url, session: Session) -> Result<()> {
116        let cc = session.country_code.as_str();
117        let pid = session.party_id.as_str();
118        let id = session.id.as_str();
119
120        let url = format!("{url}?country_code={cc}&party_id={pid}&session_id={id}");
121
122        self.req(ctx.as_extended(), Method::PUT, url)
123            .body(&session)
124            .send()
125            .await
126    }
127}
128
129struct ReqBuilder {
130    // Storing url here for better errors.
131    url: String,
132    b: reqwest::RequestBuilder,
133}
134
135impl ReqBuilder {
136    fn body<T: serde::Serialize + ?Sized>(mut self, b: &T) -> Self {
137        self.b = self.b.json(b);
138        self
139    }
140
141    async fn send<T>(self) -> Result<T>
142    where
143        T: serde::de::DeserializeOwned,
144    {
145        let url = self.url;
146
147        let rep = self
148            .b
149            .send()
150            .await
151            .map_err(|err| ServerError::unusable_api(err.to_string()))?;
152
153        let status = rep.status();
154
155        if !status.is_success() {
156            return Err(ServerError::unusable_api(format!(
157                "Non 2xx-reply: {} from server",
158                status.as_u16()
159            )))?;
160        }
161
162        let body = rep.json::<Reply<T>>().await.map_err(|err| {
163            ServerError::unusable_api(format!(
164                "Error parsing result from `{}` as json: {}",
165                url, err
166            ))
167        })?;
168
169        let code = body.status_code / 100;
170        // 1000 is generic success.
171        // 19 is custom success range.
172        if !(code == 10 || code == 19) {
173            return Err(ServerError::unusable_api(format!(
174                "Non non success status_code `{} ({})` reply from `{}`. With message: `{}`",
175                body.status_code,
176                code,
177                url,
178                body.message.as_deref().unwrap_or("")
179            )))?;
180        }
181
182        match body.data {
183            Some(body) => Ok(body),
184            None => Err(ServerError::unusable_api(format!(
185                "Received unexpected empty body from `{}` with message: `{}`",
186                url,
187                body.message.as_deref().unwrap_or("")
188            )))?,
189        }
190    }
191}
192
193trait SetOcpiCtx {
194    fn set_ocpi_ctx(self, ctx: ExtendedContext<'_>) -> Self;
195}
196
197impl SetOcpiCtx for reqwest::RequestBuilder {
198    fn set_ocpi_ctx(self, ctx: ExtendedContext<'_>) -> Self {
199        let b64 = base64::encode(ctx.credentials_token.as_str());
200        self.header("Authorization", format!("Token {}", b64))
201            .header("X-Request-Id", ctx.request_id)
202            .header("X-Correlation-Id", ctx.correlation_id)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::test_ctx;
    use serde_json::json;
    use wiremock::{matchers, Mock, MockServer, ResponseTemplate};

    /// Runs `get_endpoints_for_version` against a mock server that lists two
    /// versions and serves the details of version 2.2, then checks the details
    /// that come back.
    #[tokio::test]
    #[rustfmt::skip::macros(json)]
    async fn test_endpoints_for_version() {
        let client = Client::default();
        let server = MockServer::start().await;
        let base = server.uri();

        // Reply of the versions endpoint: only the 2.2 entry points back at the mock.
        let versions_reply = json!({
            "status_code": 1000,
            "status_message": "Success",
            "timestamp": "2015-06-30T21:59:59Z",
            "data": [
                {
                    "version": "2.1.1",
                    "url": "http://www.server.com/ocpi/2.1.1/"
                },
                {
                    "version": "2.2",
                    "url": format!("{}/2.2", base)
                }
            ]
        });

        // Reply of the 2.2 details endpoint.
        let details_reply = json!({
            "status_code": 1000,
            "status_message": "Success",
            "timestamp": "2015-06-30T21:59:59Z",
            "data": {
                "version": "2.2",
                "endpoints": [
                    {
                        "identifier": "credentials",
                        "role": "SENDER",
                        "url": format!("{}/2.2/credentials", base)
                    }
                ]
            }
        });

        Mock::given(matchers::method("GET"))
            .and(matchers::path("/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(versions_reply))
            .mount(&server)
            .await;

        Mock::given(matchers::method("GET"))
            .and(matchers::path("/2.2"))
            .respond_with(ResponseTemplate::new(200).set_body_json(details_reply))
            .mount(&server)
            .await;

        let versions_url = format!("{}/versions", base)
            .parse::<types::Url>()
            .expect("Versions url");

        let details = client
            .get_endpoints_for_version(
                test_ctx().extend(&"imatoken".parse().unwrap()),
                versions_url.clone(),
                types::VersionNumber::V2_2,
            )
            .await
            .expect(&format!("Making request to {}", versions_url));

        assert_eq!(details.version, types::VersionNumber::V2_2);
        assert_eq!(details.endpoints.len(), 1);

        let endpoint = &details.endpoints[0];
        assert_eq!(endpoint.identifier, types::ModuleId::Credentials);
        assert_eq!(endpoint.role, types::InterfaceRole::Sender);
    }
}