1pub use multilink::http::*;
2
3use multilink::{
4 error::ProtocolErrorType,
5 http::hyper::{Body, Method, StatusCode, Uri},
6 http::util::{
7 notification_sse_response, notification_sse_stream, parse_request, parse_response,
8 serialize_to_http_request, serialize_to_http_response, validate_method,
9 },
10 util::parse_from_value,
11 ProtocolError, ServiceResponse,
12};
13use serde_json::Value;
14
15use crate::service::{BackendRequest, BackendResponse, CoreRequest, CoreResponse};
16
// URI paths shared by the HTTP request encoders (`to_http_request`) and decoders
// (`from_http_request`) in this module. Note that the constants suffixed with
// `_METHOD` / `_METHOD_PREFIX` hold URI *paths*, not HTTP methods.

/// Path of a `Generation` request (sent with `POST`).
const GENERATE_PATH: &str = "/generate";
/// Path of a `GenerationStream` request (sent with `POST`).
const GENERATE_STREAM_PATH: &str = "/generate_stream";
/// Path of a `GetAllThreadInfos` request (sent with `GET`).
const GET_ALL_THREAD_INFOS_METHOD: &str = "/threads";
/// Path of a `GetLastThreadInfo` request (sent with `GET`).
const GET_LAST_THREAD_INFO_METHOD: &str = "/threads/last";
/// Prefix of a `GetThreadMessages` request path; the thread id follows it directly.
const GET_THREAD_MESSAGES_METHOD_PREFIX: &str = "/threads/";
22
23#[async_trait::async_trait]
24impl RequestHttpConvert<CoreRequest> for CoreRequest {
25 async fn from_http_request(request: HttpRequest<Body>) -> Result<Option<Self>, ProtocolError> {
26 let path = request.uri().path();
27 let request = match path {
28 GENERATE_PATH => {
29 validate_method(&request, Method::POST)?;
30 CoreRequest::Generation(parse_request(request).await?)
31 }
32 GENERATE_STREAM_PATH => {
33 validate_method(&request, Method::POST)?;
34 CoreRequest::GenerationStream(parse_request(request).await?)
35 }
36 GET_ALL_THREAD_INFOS_METHOD => {
37 validate_method(&request, Method::GET)?;
38 CoreRequest::GetAllThreadInfos
39 }
40 GET_LAST_THREAD_INFO_METHOD => {
41 validate_method(&request, Method::GET)?;
42 CoreRequest::GetLastThreadInfo
43 }
44 _ => {
45 if !path.starts_with(GET_THREAD_MESSAGES_METHOD_PREFIX)
46 || request.method() != &Method::GET
47 || path.split(&['/', '\\']).count() != 2
48 {
49 return Ok(None);
50 }
51 let id = path.split(&['/', '\\']).nth(1).unwrap();
52 CoreRequest::GetThreadMessages { id: id.to_string() }
53 }
54 };
55 Ok(Some(request))
56 }
57
58 fn to_http_request(&self, base_url: &Uri) -> Result<Option<HttpRequest<Body>>, ProtocolError> {
59 let request = match self {
60 CoreRequest::Generation(request) => {
61 serialize_to_http_request(base_url, GENERATE_PATH, Method::POST, &request)?
62 }
63 CoreRequest::GenerationStream(request) => {
64 serialize_to_http_request(base_url, GENERATE_STREAM_PATH, Method::POST, &request)?
65 }
66 CoreRequest::GetLastThreadInfo => serialize_to_http_request(
67 base_url,
68 GET_LAST_THREAD_INFO_METHOD,
69 Method::GET,
70 &Value::Null,
71 )?,
72 CoreRequest::GetAllThreadInfos => serialize_to_http_request(
73 base_url,
74 GET_ALL_THREAD_INFOS_METHOD,
75 Method::GET,
76 &Value::Null,
77 )?,
78 CoreRequest::GetThreadMessages { id } => {
79 serialize_to_http_request(base_url, GET_ALL_THREAD_INFOS_METHOD, Method::GET, &id)?
80 }
81 _ => return Ok(None),
82 };
83 Ok(Some(request))
84 }
85}
86
87#[async_trait::async_trait]
88impl ResponseHttpConvert<CoreRequest, CoreResponse> for CoreResponse {
89 async fn from_http_response(
90 response: ModalHttpResponse,
91 original_request: &CoreRequest,
92 ) -> Result<Option<ServiceResponse<Self>>, ProtocolError> {
93 Ok(Some(match response {
94 ModalHttpResponse::Single(response) => match original_request {
95 CoreRequest::Generation(_) => ServiceResponse::Single(CoreResponse::Generation(
96 parse_response(response).await?,
97 )),
98 CoreRequest::GenerationStream(_) => ServiceResponse::Multiple(
99 notification_sse_stream(original_request.clone(), response),
100 ),
101 CoreRequest::GetLastThreadInfo => ServiceResponse::Single(
102 CoreResponse::GetLastThreadInfo(parse_response(response).await?),
103 ),
104 CoreRequest::GetAllThreadInfos => ServiceResponse::Single(
105 CoreResponse::GetAllThreadInfos(parse_response(response).await?),
106 ),
107 CoreRequest::GetThreadMessages { .. } => ServiceResponse::Single(
108 CoreResponse::GetThreadMessages(parse_response(response).await?),
109 ),
110 _ => return Ok(None),
111 },
112 ModalHttpResponse::Event(event) => ServiceResponse::Single(match original_request {
113 CoreRequest::GenerationStream(_) => {
114 CoreResponse::GenerationStream(parse_from_value(event)?)
115 }
116 _ => return Ok(None),
117 }),
118 }))
119 }
120
121 fn to_http_response(
122 response: ServiceResponse<Self>,
123 ) -> Result<Option<ModalHttpResponse>, ProtocolError> {
124 let response = match response {
125 ServiceResponse::Single(response) => match response {
126 CoreResponse::Generation(response) => ModalHttpResponse::Single(
127 serialize_to_http_response(&response, StatusCode::OK)?,
128 ),
129 CoreResponse::GenerationStream(response) => {
130 ModalHttpResponse::Event(serde_json::to_value(response).unwrap())
131 }
132 CoreResponse::GetLastThreadInfo(response) => ModalHttpResponse::Single(
133 serialize_to_http_response(&response, StatusCode::OK)?,
134 ),
135 CoreResponse::GetAllThreadInfos(response) => ModalHttpResponse::Single(
136 serialize_to_http_response(&response, StatusCode::OK)?,
137 ),
138 CoreResponse::GetThreadMessages(response) => ModalHttpResponse::Single(
139 serialize_to_http_response(&response, StatusCode::OK)?,
140 ),
141 _ => return Ok(None),
142 },
143 ServiceResponse::Multiple(stream) => {
144 ModalHttpResponse::Single(notification_sse_response(stream))
145 }
146 };
147 Ok(Some(response))
148 }
149}
150
151#[async_trait::async_trait]
152impl RequestHttpConvert<BackendRequest> for BackendRequest {
153 async fn from_http_request(request: HttpRequest<Body>) -> Result<Option<Self>, ProtocolError> {
154 let request = match request.uri().path() {
155 GENERATE_PATH => match request.method() == &Method::POST {
156 true => BackendRequest::Generation(parse_request(request).await?),
157 false => return Err(generic_error(ProtocolErrorType::HttpMethodNotAllowed).into()),
158 },
159 GENERATE_STREAM_PATH => match request.method() == &Method::POST {
160 true => BackendRequest::GenerationStream(parse_request(request).await?),
161 false => return Err(generic_error(ProtocolErrorType::HttpMethodNotAllowed).into()),
162 },
163 _ => return Ok(None),
164 };
165 Ok(Some(request))
166 }
167
168 fn to_http_request(&self, base_url: &Uri) -> Result<Option<HttpRequest<Body>>, ProtocolError> {
169 let request = match self {
170 BackendRequest::Generation(request) => {
171 serialize_to_http_request(base_url, GENERATE_PATH, Method::POST, &request)?
172 }
173 BackendRequest::GenerationStream(request) => {
174 serialize_to_http_request(base_url, GENERATE_STREAM_PATH, Method::POST, &request)?
175 }
176 };
177 Ok(Some(request))
178 }
179}
180
181#[async_trait::async_trait]
182impl ResponseHttpConvert<BackendRequest, BackendResponse> for BackendResponse {
183 async fn from_http_response(
184 response: ModalHttpResponse,
185 original_request: &BackendRequest,
186 ) -> Result<Option<ServiceResponse<Self>>, ProtocolError> {
187 let response = match response {
188 ModalHttpResponse::Single(response) => match original_request {
189 BackendRequest::Generation(_) => ServiceResponse::Single(
190 BackendResponse::Generation(parse_response(response).await?),
191 ),
192 BackendRequest::GenerationStream(_) => ServiceResponse::Multiple(
193 notification_sse_stream(original_request.clone(), response),
194 ),
195 },
196 ModalHttpResponse::Event(event) => ServiceResponse::Single(match original_request {
197 BackendRequest::GenerationStream(_) => {
198 BackendResponse::GenerationStream(parse_from_value(event)?)
199 }
200 _ => return Ok(None),
201 }),
202 };
203 Ok(Some(response))
204 }
205
206 fn to_http_response(
207 response: ServiceResponse<Self>,
208 ) -> Result<Option<ModalHttpResponse>, ProtocolError> {
209 Ok(Some(match response {
210 ServiceResponse::Single(response) => match response {
211 BackendResponse::Generation(response) => ModalHttpResponse::Single(
212 serialize_to_http_response(&response, StatusCode::OK)?,
213 ),
214 BackendResponse::GenerationStream(response) => {
215 ModalHttpResponse::Event(serde_json::to_value(response).unwrap())
216 }
217 },
218 ServiceResponse::Multiple(stream) => {
219 ModalHttpResponse::Single(notification_sse_response(stream))
220 }
221 }))
222 }
223}