pink_extension_runtime/
lib.rs

1use std::borrow::Cow;
2use std::io::Write;
3use std::{
4    fmt::Display,
5    str::FromStr,
6    time::{Duration, SystemTime},
7};
8
9use pink_extension::{
10    chain_extension::{
11        self as ext, HttpRequest, HttpRequestError, HttpResponse, JsCode, JsValue, PinkExtBackend,
12        SigType, StorageQuotaExceeded,
13    },
14    Balance, EcdhPublicKey, EcdsaPublicKey, EcdsaSignature, Hash,
15};
16use reqwest::{
17    header::{HeaderMap, HeaderName, HeaderValue},
18    Method,
19};
20use reqwest_env_proxy::EnvProxyBuilder;
21use sp_core::{ByteArray as _, Pair};
22
23pub mod local_cache;
24pub mod mock_ext;
25
/// Environment that a pink extension backend runs in.
///
/// Implementors expose the account the backend is acting for; in this file the address is
/// used to prefix log lines emitted through the extension's `log` call.
pub trait PinkRuntimeEnv {
    /// Account identifier type: printable and viewable as raw bytes.
    type AccountId: Display + AsRef<[u8]>;

    /// Returns a reference to the account address associated with this environment.
    fn address(&self) -> &Self::AccountId;
}
31
/// Extension backend that borrows a runtime environment `T` and reports failures as `Error`.
///
/// The `Error` parameter is carried only at the type level (via `PhantomData`); no error value
/// is stored in the struct.
pub struct DefaultPinkExtension<'a, T, Error> {
    /// The borrowed runtime environment.
    pub env: &'a T,
    // Zero-sized marker tying the `Error` type parameter to the struct.
    _e: std::marker::PhantomData<Error>,
}

impl<'a, T, E> DefaultPinkExtension<'a, T, E> {
    /// Wraps a reference to `env` in a new extension backend.
    pub fn new(env: &'a T) -> Self {
        let _e = std::marker::PhantomData::<E>;
        Self { env, _e }
    }
}
45
46fn block_on<F: core::future::Future>(f: F) -> F::Output {
47    match tokio::runtime::Handle::try_current() {
48        Ok(handle) => handle.block_on(f),
49        Err(_) => tokio::runtime::Runtime::new()
50            .expect("Failed to create tokio runtime")
51            .block_on(f),
52    }
53}
54
55pub fn batch_http_request(requests: Vec<HttpRequest>, timeout_ms: u64) -> ext::BatchHttpResult {
56    const MAX_CONCURRENT_REQUESTS: usize = 5;
57    if requests.len() > MAX_CONCURRENT_REQUESTS {
58        return Err(ext::HttpRequestError::TooManyRequests);
59    }
60    block_on(async move {
61        let futs = requests
62            .into_iter()
63            .map(|request| async_http_request(request, timeout_ms));
64        tokio::time::timeout(
65            Duration::from_millis(timeout_ms + 200),
66            futures::future::join_all(futs),
67        )
68        .await
69    })
70    .or(Err(ext::HttpRequestError::Timeout))
71}
72
73pub fn http_request(
74    request: HttpRequest,
75    timeout_ms: u64,
76) -> Result<HttpResponse, HttpRequestError> {
77    use HttpRequestError::*;
78    match block_on(async_http_request(request, timeout_ms)) {
79        Ok(resp) => Ok(resp),
80        Err(err) => match err {
81            // runtime v1.0 supported errors
82            InvalidUrl | InvalidMethod | InvalidHeaderName | InvalidHeaderValue
83            | FailedToCreateClient | Timeout => Err(err),
84            _ => {
85                // To be compatible with runtime v1.0, we need to convert the v1.1 extended errors
86                // to an HTTP response with status code 524.
87                log::error!("chain_ext: http request failed: {}", err.display());
88                Ok(HttpResponse {
89                    status_code: 524,
90                    reason_phrase: "IO Error".into(),
91                    body: format!("{err:?}").into_bytes(),
92                    headers: vec![],
93                })
94            }
95        },
96    }
97}
98
99async fn async_http_request(
100    request: HttpRequest,
101    timeout_ms: u64,
102) -> Result<HttpResponse, HttpRequestError> {
103    if timeout_ms == 0 {
104        return Err(HttpRequestError::Timeout);
105    }
106    let timeout = Duration::from_millis(timeout_ms);
107    let url: reqwest::Url = request.url.parse().or(Err(HttpRequestError::InvalidUrl))?;
108    let client = reqwest::Client::builder()
109        .trust_dns(true)
110        .timeout(timeout)
111        .env_proxy(url.host_str().unwrap_or_default())
112        .build()
113        .or(Err(HttpRequestError::FailedToCreateClient))?;
114
115    let method: Method =
116        FromStr::from_str(request.method.as_str()).or(Err(HttpRequestError::InvalidMethod))?;
117    let mut headers = HeaderMap::new();
118    for (key, value) in &request.headers {
119        let key =
120            HeaderName::from_str(key.as_str()).or(Err(HttpRequestError::InvalidHeaderName))?;
121        let value = HeaderValue::from_str(value).or(Err(HttpRequestError::InvalidHeaderValue))?;
122        headers.insert(key, value);
123    }
124
125    let result = client
126        .request(method, url)
127        .headers(headers)
128        .body(request.body)
129        .send()
130        .await;
131
132    let mut response = match result {
133        Ok(response) => response,
134        Err(err) => {
135            // If there is somthing wrong with the network, we can not inspect the reason too
136            // much here. Let it return a non-standard 523 here.
137            return Ok(HttpResponse {
138                status_code: 523,
139                reason_phrase: "Unreachable".into(),
140                body: format!("{err:?}").into_bytes(),
141                headers: vec![],
142            });
143        }
144    };
145
146    let headers: Vec<_> = response
147        .headers()
148        .iter()
149        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().into()))
150        .collect();
151
152    const MAX_BODY_SIZE: usize = 1024 * 1024 * 2; // 2MB
153
154    let mut body = Vec::new();
155    let mut writer = LimitedWriter::new(&mut body, MAX_BODY_SIZE);
156
157    while let Some(chunk) = response
158        .chunk()
159        .await
160        .or(Err(HttpRequestError::NetworkError))?
161    {
162        writer
163            .write_all(&chunk)
164            .or(Err(HttpRequestError::ResponseTooLarge))?;
165    }
166
167    let response = HttpResponse {
168        status_code: response.status().as_u16(),
169        reason_phrase: response
170            .status()
171            .canonical_reason()
172            .unwrap_or_default()
173            .into(),
174        body,
175        headers,
176    };
177    Ok(response)
178}
179
180impl<T: PinkRuntimeEnv, E: From<&'static str>> PinkExtBackend for DefaultPinkExtension<'_, T, E> {
181    type Error = E;
182    fn http_request(&self, request: HttpRequest) -> Result<HttpResponse, Self::Error> {
183        http_request(request, 10 * 1000).map_err(|err| err.display().into())
184    }
185
186    fn batch_http_request(
187        &self,
188        requests: Vec<HttpRequest>,
189        timeout_ms: u64,
190    ) -> Result<ext::BatchHttpResult, Self::Error> {
191        Ok(batch_http_request(requests, timeout_ms))
192    }
193
194    fn sign(
195        &self,
196        sigtype: SigType,
197        key: Cow<[u8]>,
198        message: Cow<[u8]>,
199    ) -> Result<Vec<u8>, Self::Error> {
200        macro_rules! sign_with {
201            ($sigtype:ident) => {{
202                let pair = sp_core::$sigtype::Pair::from_seed_slice(&key).or(Err("Invalid key"))?;
203                let signature = pair.sign(&message);
204                let signature: &[u8] = signature.as_ref();
205                signature.to_vec()
206            }};
207        }
208
209        Ok(match sigtype {
210            SigType::Sr25519 => sign_with!(sr25519),
211            SigType::Ed25519 => sign_with!(ed25519),
212            SigType::Ecdsa => sign_with!(ecdsa),
213        })
214    }
215
216    fn verify(
217        &self,
218        sigtype: SigType,
219        pubkey: Cow<[u8]>,
220        message: Cow<[u8]>,
221        signature: Cow<[u8]>,
222    ) -> Result<bool, Self::Error> {
223        macro_rules! verify_with {
224            ($sigtype:ident) => {{
225                let pubkey = sp_core::$sigtype::Public::from_slice(&pubkey)
226                    .map_err(|_| "Invalid public key")?;
227                let signature = sp_core::$sigtype::Signature::from_slice(&signature)
228                    .ok_or("Invalid signature")?;
229                Ok(sp_core::$sigtype::Pair::verify(
230                    &signature, message, &pubkey,
231                ))
232            }};
233        }
234        match sigtype {
235            SigType::Sr25519 => verify_with!(sr25519),
236            SigType::Ed25519 => verify_with!(ed25519),
237            SigType::Ecdsa => verify_with!(ecdsa),
238        }
239    }
240
241    fn derive_sr25519_key(&self, salt: Cow<[u8]>) -> Result<Vec<u8>, Self::Error> {
242        // This default implementation is for unit tests. The host should override this.
243        let mut seed: <sp_core::sr25519::Pair as Pair>::Seed = Default::default();
244        let len = seed.len().min(salt.len());
245        seed[..len].copy_from_slice(&salt[..len]);
246        let key = sp_core::sr25519::Pair::from_seed(&seed);
247
248        Ok(key.as_ref().secret.to_bytes().to_vec())
249    }
250
251    fn get_public_key(&self, sigtype: SigType, key: Cow<[u8]>) -> Result<Vec<u8>, Self::Error> {
252        macro_rules! public_key_with {
253            ($sigtype:ident) => {{
254                sp_core::$sigtype::Pair::from_seed_slice(&key)
255                    .or(Err("Invalid key"))?
256                    .public()
257                    .to_raw_vec()
258            }};
259        }
260        let pubkey = match sigtype {
261            SigType::Ed25519 => public_key_with!(ed25519),
262            SigType::Sr25519 => public_key_with!(sr25519),
263            SigType::Ecdsa => public_key_with!(ecdsa),
264        };
265        Ok(pubkey)
266    }
267
268    fn cache_set(
269        &self,
270        _key: Cow<[u8]>,
271        _value: Cow<[u8]>,
272    ) -> Result<Result<(), StorageQuotaExceeded>, Self::Error> {
273        Ok(Ok(()))
274    }
275
276    fn cache_set_expiration(&self, _key: Cow<[u8]>, _expire: u64) -> Result<(), Self::Error> {
277        Ok(())
278    }
279
280    fn cache_get(&self, _key: Cow<'_, [u8]>) -> Result<Option<Vec<u8>>, Self::Error> {
281        Ok(None)
282    }
283
284    fn cache_remove(&self, _key: Cow<'_, [u8]>) -> Result<Option<Vec<u8>>, Self::Error> {
285        Ok(None)
286    }
287
288    fn log(&self, level: u8, message: Cow<str>) -> Result<(), Self::Error> {
289        let address = self.env.address();
290        let level = match level {
291            1 => log::Level::Error,
292            2 => log::Level::Warn,
293            3 => log::Level::Info,
294            4 => log::Level::Debug,
295            5 => log::Level::Trace,
296            _ => log::Level::Error,
297        };
298        log::log!(target: "pink", level, "[{}] {}", address, message);
299        Ok(())
300    }
301
302    fn getrandom(&self, length: u8) -> Result<Vec<u8>, Self::Error> {
303        let mut buf = vec![0u8; length as _];
304        getrandom::getrandom(&mut buf[..]).or(Err("Failed to get random bytes"))?;
305        Ok(buf)
306    }
307
308    fn is_in_transaction(&self) -> Result<bool, Self::Error> {
309        Ok(false)
310    }
311
312    fn ecdsa_sign_prehashed(
313        &self,
314        key: Cow<[u8]>,
315        message_hash: Hash,
316    ) -> Result<EcdsaSignature, Self::Error> {
317        let pair = sp_core::ecdsa::Pair::from_seed_slice(&key).or(Err("Invalid key"))?;
318        let signature = pair.sign_prehashed(&message_hash);
319        Ok(signature.0)
320    }
321
322    fn ecdsa_verify_prehashed(
323        &self,
324        signature: EcdsaSignature,
325        message_hash: Hash,
326        pubkey: EcdsaPublicKey,
327    ) -> Result<bool, Self::Error> {
328        let public = sp_core::ecdsa::Public(pubkey);
329        let sig = sp_core::ecdsa::Signature(signature);
330        Ok(sp_core::ecdsa::Pair::verify_prehashed(
331            &sig,
332            &message_hash,
333            &public,
334        ))
335    }
336
337    fn system_contract_id(&self) -> Result<ext::AccountId, Self::Error> {
338        Err("No default system contract id".into())
339    }
340
341    fn balance_of(&self, _account: ext::AccountId) -> Result<(Balance, Balance), Self::Error> {
342        Ok((0, 0))
343    }
344
345    fn untrusted_millis_since_unix_epoch(&self) -> Result<u64, Self::Error> {
346        let duration = SystemTime::now()
347            .duration_since(SystemTime::UNIX_EPOCH)
348            .or(Err("The system time is earlier than UNIX_EPOCH"))?;
349        Ok(duration.as_millis() as u64)
350    }
351
352    fn worker_pubkey(&self) -> Result<EcdhPublicKey, Self::Error> {
353        Ok(Default::default())
354    }
355
356    fn code_exists(&self, _code_hash: Hash, _sidevm: bool) -> Result<bool, Self::Error> {
357        Ok(false)
358    }
359
360    fn import_latest_system_code(
361        &self,
362        _payer: ext::AccountId,
363    ) -> Result<Option<Hash>, Self::Error> {
364        Ok(None)
365    }
366
367    fn runtime_version(&self) -> Result<(u32, u32), Self::Error> {
368        Ok((1, 0))
369    }
370
371    fn current_event_chain_head(&self) -> Result<(u64, Hash), Self::Error> {
372        Ok((0, Default::default()))
373    }
374
375    fn js_eval(&self, _codes: Vec<JsCode>, _args: Vec<String>) -> Result<JsValue, Self::Error> {
376        Ok(JsValue::Exception("No Js Runtime".into()))
377    }
378}
379
/// Writer adapter that caps the total number of bytes forwarded to the wrapped writer.
///
/// A write that would push the running total past `limit` is rejected as a whole with an
/// `ErrorKind::Other` error; nothing from that buffer is forwarded.
struct LimitedWriter<W> {
    /// The wrapped destination.
    writer: W,
    /// Number of bytes the wrapped writer has accepted so far.
    written: usize,
    /// Maximum total number of bytes that may be written.
    limit: usize,
}

impl<W> LimitedWriter<W> {
    /// Wraps `writer`, allowing at most `limit` bytes to pass through in total.
    fn new(writer: W, limit: usize) -> Self {
        let written = 0;
        Self {
            writer,
            written,
            limit,
        }
    }
}

impl<W: std::io::Write> std::io::Write for LimitedWriter<W> {
    /// Forwards `buf` to the wrapped writer if the whole buffer still fits under the limit.
    ///
    /// Only the byte count actually accepted by the wrapped writer is added to the total.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let projected = self.written + buf.len();
        if projected <= self.limit {
            let accepted = self.writer.write(buf)?;
            self.written += accepted;
            return Ok(accepted);
        }
        Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            "Buffer limit exceeded",
        ))
    }

    /// Flushes the wrapped writer.
    fn flush(&mut self) -> std::io::Result<()> {
        self.writer.flush()
    }
}