use pyo3::{ffi::PyObject, prelude::*};
use std::io::{Read, Write};
use std::net::IpAddr;
use std::{ffi::c_int, slice};
use crate::ssl::{HandshakeError, Ssl, SslStream};
use error::ErrorStack;
pub use loader::{get, load};
use sys::*;
mod bio;
pub mod error;
mod loader;
pub mod ssl;
mod sys;
/// Handle to the OpenSSL `SSL_CTX` that backs a Python SSL context object.
///
/// Values are built through the `TryFrom<&Bound<PyAny>>` conversion, which
/// reads the raw `SSL_CTX` pointer out of the Python object and stores a
/// strong reference to that object next to it.
pub struct SSLContext {
// Raw OpenSSL context pointer taken from the Python object; the conversion
// rejects null, so this is non-null for every value it produces.
ptr: *mut SSL_CTX,
// Strong reference to the originating Python object.
// NOTE(review): presumably held so the Python side cannot free the SSL_CTX
// while `ptr` is still in use — confirm against the owner's lifetime rules.
pyobj: Py<PyAny>,
}
impl TryFrom<&Bound<'_, PyAny>> for SSLContext {
    type Error = PyErr;

    /// Extracts the native `SSL_CTX` pointer from a Python SSL context object.
    ///
    /// # Errors
    ///
    /// * `TypeError` if `obj` is not an instance of `_ssl._SSLContext` (the
    ///   native base class of `ssl.SSLContext`).
    /// * `ValueError` if the object's `SSL_CTX` pointer is null.
    /// * Any error raised while importing `_ssl` or looking up the class.
    fn try_from(obj: &Bound<PyAny>) -> PyResult<Self> {
        // Leading fields of the C struct behind `_ssl._SSLContext`
        // (`PySSLContext` in CPython's Modules/_ssl.c): the object header
        // followed directly by the `SSL_CTX*`. Only this prefix is read.
        #[repr(C)]
        struct PySSLContext {
            ob_base: PyObject,
            ctx: *mut SSL_CTX,
        }

        // The cast below is only meaningful for objects that really have the
        // `PySSLContext` layout. Without this check any Python object (an int,
        // `None`, ...) would be reinterpreted and read out of layout, which is
        // undefined behaviour reachable from safe code.
        let ctx_type = obj.py().import("_ssl")?.getattr("_SSLContext")?;
        if !obj.is_instance(&ctx_type)? {
            return Err(pyo3::exceptions::PyTypeError::new_err(
                "expected an ssl.SSLContext instance",
            ));
        }

        // SAFETY: `obj` was just verified to be an `_ssl._SSLContext` (or a
        // subclass), so its allocation starts with the fields mirrored by
        // `PySSLContext`. `obj` is a live bound reference, so the pointer
        // returned by `as_ptr` is valid for this read.
        let ptr = unsafe { (*(obj.as_ptr() as *const PySSLContext)).ctx };
        if ptr.is_null() {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "SSLContext has null SSL_CTX",
            ));
        }

        Ok(Self {
            ptr,
            // Store a strong reference to the Python object alongside the
            // raw pointer taken from it.
            pyobj: obj.clone().unbind(),
        })
    }
}
impl SSLContext {
    /// Performs the client side of a TLS handshake over `stream`.
    ///
    /// A fresh `Ssl` is created from this context. `domain` is handed to
    /// `Ssl::set_hostname` unless it parses as an IP address literal, in
    /// which case that call is skipped.
    ///
    /// # Errors
    ///
    /// Returns a `HandshakeError` if creating the `Ssl`, setting the
    /// hostname, or the handshake itself fails.
    pub fn connect<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
    where
        S: Read + Write,
    {
        let mut session = Ssl::new(self.ptr)?;
        let is_ip_literal = domain.parse::<IpAddr>().is_ok();
        if !is_ip_literal {
            session.set_hostname(domain)?;
        }
        session.connect(stream)
    }

    /// Performs the server side of a TLS handshake over `stream`.
    ///
    /// # Errors
    ///
    /// Returns a `HandshakeError` if creating the `Ssl` or the handshake
    /// itself fails.
    pub fn accept<S>(&self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
    where
        S: Read + Write,
    {
        Ssl::new(self.ptr)?.accept(stream)
    }
}
impl Clone for SSLContext {
    /// Produces a second handle to the same `SSL_CTX`.
    ///
    /// The raw pointer is copied as-is; the Python object reference is
    /// duplicated with `clone_ref` inside `Python::attach`.
    fn clone(&self) -> Self {
        let pyobj = Python::attach(|py| self.pyobj.clone_ref(py));
        Self {
            ptr: self.ptr,
            pyobj,
        }
    }
}
/// Converts a pointer-returning call's result into a `Result`.
///
/// A non-null pointer is passed through unchanged; a null pointer is mapped
/// to the value returned by `ErrorStack::get()`.
#[inline]
fn cvt_p<T>(r: *mut T) -> Result<*mut T, ErrorStack> {
    if !r.is_null() {
        return Ok(r);
    }
    Err(ErrorStack::get())
}
/// Converts an integer return code into a `Result`.
///
/// Strictly positive values are passed through unchanged; zero and negative
/// values are mapped to the value returned by `ErrorStack::get()`.
#[inline]
fn cvt(r: c_int) -> Result<c_int, ErrorStack> {
    match r {
        code if code > 0 => Ok(code),
        _ => Err(ErrorStack::get()),
    }
}
/// Builds a shared slice from a raw pointer and an element count.
///
/// When `len` is zero an empty slice is returned and `data` is never passed
/// to `slice::from_raw_parts`, so `data` may be null or dangling in that case.
///
/// # Safety
///
/// When `len` is non-zero the caller must uphold every requirement of
/// `std::slice::from_raw_parts` for `data` and `len`, for the whole of the
/// caller-chosen lifetime `'a`.
unsafe fn from_raw_parts<'a, T>(data: *const T, len: usize) -> &'a [T] {
    match len {
        0 => &[],
        // SAFETY: `len` is non-zero here; validity of `data` for `len`
        // elements is the caller's obligation per this function's contract.
        _ => unsafe { slice::from_raw_parts(data, len) },
    }
}
/// Builds a mutable slice from a raw pointer and an element count.
///
/// When `len` is zero an empty slice is returned and `data` is never passed
/// to `slice::from_raw_parts_mut`, so `data` may be null or dangling in that
/// case.
///
/// # Safety
///
/// When `len` is non-zero the caller must uphold every requirement of
/// `std::slice::from_raw_parts_mut` for `data` and `len`, for the whole of
/// the caller-chosen lifetime `'a`.
unsafe fn from_raw_parts_mut<'a, T>(data: *mut T, len: usize) -> &'a mut [T] {
    match len {
        0 => &mut [],
        // SAFETY: `len` is non-zero here; validity and exclusivity of `data`
        // for `len` elements is the caller's obligation per this function's
        // contract.
        _ => unsafe { slice::from_raw_parts_mut(data, len) },
    }
}