1use std::sync::Arc;
2
3use hyper_util::rt::TokioIo;
4use tokio::net::UnixStream;
5use tonic::codegen::async_trait;
6use tonic::metadata::MetadataValue;
7use tonic::service::Interceptor;
8use tonic::service::interceptor::InterceptedService;
9use tonic::transport::{Channel, ClientTlsConfig, Endpoint, Uri};
10use tonic::{Request as GrpcRequest, Response as GrpcResponse, Status};
11use tower::service_fn;
12
13use crate::api::RuntimeMetadata;
14use crate::error::Result as ProviderResult;
15use crate::generated::v1::{
16 self as pb, agent_host_client::AgentHostClient as ProtoAgentHostClient,
17};
18
/// Transport used by the agent host client: a tonic [`Channel`] whose
/// requests pass through [`AgentHostRelayTokenInterceptor`].
type AgentHostTransport = InterceptedService<Channel, AgentHostRelayTokenInterceptor>;
20
/// Name of the environment variable that [`AgentHost::connect`] reads to find
/// the agent host transport target.
///
/// The value is parsed by `parse_agent_host_target`: `tcp://`, `tls://` and
/// `unix://` prefixes are recognised, and a value without any `://` is
/// treated as a Unix socket path.
pub const ENV_AGENT_HOST_SOCKET: &str = "GESTALT_AGENT_HOST_SOCKET";
/// Name of the environment variable that [`AgentHost::connect`] reads for the
/// relay token.
///
/// If the variable cannot be read, or is blank after trimming, no token
/// header is attached to requests.
pub const ENV_AGENT_HOST_SOCKET_TOKEN: &str = "GESTALT_AGENT_HOST_SOCKET_TOKEN";
/// Metadata key under which the relay token is inserted into outgoing requests.
const AGENT_HOST_RELAY_TOKEN_HEADER: &str = "x-gestalt-host-service-relay-token";
26
/// Errors returned by [`AgentHost`] and the helpers it uses to connect.
///
/// Every variant displays as the message of the value it wraps.
#[derive(Debug, thiserror::Error)]
pub enum AgentHostError {
    /// A tonic transport error; created automatically through `From`, so `?`
    /// can be used on endpoint construction and connection calls.
    #[error("{0}")]
    Transport(#[from] tonic::transport::Error),
    /// A gRPC status; created automatically through `From`, so `?` can be
    /// used on client calls.
    #[error("{0}")]
    Status(#[from] tonic::Status),
    /// A configuration problem described by the contained message. In this
    /// file it is used for a missing socket variable, an invalid transport
    /// target, and a relay token that is not valid metadata.
    #[error("{0}")]
    Env(String),
}
40
41pub struct AgentHost {
43 client: ProtoAgentHostClient<AgentHostTransport>,
44}
45
46impl AgentHost {
47 pub async fn connect() -> std::result::Result<Self, AgentHostError> {
49 let target = std::env::var(ENV_AGENT_HOST_SOCKET)
50 .map_err(|_| AgentHostError::Env(format!("{ENV_AGENT_HOST_SOCKET} is not set")))?;
51 let relay_token = std::env::var(ENV_AGENT_HOST_SOCKET_TOKEN).unwrap_or_default();
52 let channel = match parse_agent_host_target(&target)? {
53 AgentHostTarget::Unix(path) => connect_unix(path).await?,
54 AgentHostTarget::Tcp(address) => {
55 Endpoint::from_shared(format!("http://{address}"))?
56 .connect()
57 .await?
58 }
59 AgentHostTarget::Tls(address) => {
60 Endpoint::from_shared(format!("https://{address}"))?
61 .tls_config(ClientTlsConfig::new().with_native_roots())?
62 .connect()
63 .await?
64 }
65 };
66 Ok(Self {
67 client: ProtoAgentHostClient::with_interceptor(
68 channel,
69 agent_host_relay_token_interceptor(relay_token.trim())?,
70 ),
71 })
72 }
73
74 pub async fn execute_tool(
76 &mut self,
77 request: pb::ExecuteAgentToolRequest,
78 ) -> std::result::Result<pb::ExecuteAgentToolResponse, AgentHostError> {
79 Ok(self.client.execute_tool(request).await?.into_inner())
80 }
81
82 pub async fn list_tools(
84 &mut self,
85 request: pb::ListAgentToolsRequest,
86 ) -> std::result::Result<pb::ListAgentToolsResponse, AgentHostError> {
87 Ok(self.client.list_tools(request).await?.into_inner())
88 }
89
90 pub async fn resolve_connection(
92 &mut self,
93 request: pb::ResolveAgentConnectionRequest,
94 ) -> std::result::Result<pb::ResolvedAgentConnection, AgentHostError> {
95 Ok(self.client.resolve_connection(request).await?.into_inner())
96 }
97}
98
99async fn connect_unix(
100 socket_path: String,
101) -> std::result::Result<Channel, tonic::transport::Error> {
102 Endpoint::try_from("http://[::]:50051")?
103 .connect_with_connector(service_fn(move |_: Uri| {
104 let path = socket_path.clone();
105 async move { UnixStream::connect(path).await.map(TokioIo::new) }
106 }))
107 .await
108}
109
110#[derive(Clone)]
111struct AgentHostRelayTokenInterceptor {
112 token: Option<MetadataValue<tonic::metadata::Ascii>>,
113}
114
115impl Interceptor for AgentHostRelayTokenInterceptor {
116 fn call(
117 &mut self,
118 mut request: tonic::Request<()>,
119 ) -> std::result::Result<tonic::Request<()>, tonic::Status> {
120 if let Some(token) = self.token.clone() {
121 request
122 .metadata_mut()
123 .insert(AGENT_HOST_RELAY_TOKEN_HEADER, token);
124 }
125 Ok(request)
126 }
127}
128
129fn agent_host_relay_token_interceptor(
130 token: &str,
131) -> std::result::Result<AgentHostRelayTokenInterceptor, AgentHostError> {
132 let trimmed = token.trim();
133 let token = if trimmed.is_empty() {
134 None
135 } else {
136 Some(MetadataValue::try_from(trimmed).map_err(|err| {
137 AgentHostError::Env(format!("agent host: invalid relay token metadata: {err}"))
138 })?)
139 };
140 Ok(AgentHostRelayTokenInterceptor { token })
141}
142
/// Parsed form of the agent host transport target.
enum AgentHostTarget {
    /// Path of a Unix domain socket.
    Unix(String),
    /// Address (the text after `tcp://`) dialled over plain `http://`.
    Tcp(String),
    /// Address (the text after `tls://`) dialled over `https://`.
    Tls(String),
}
148
149fn parse_agent_host_target(raw: &str) -> std::result::Result<AgentHostTarget, AgentHostError> {
150 let target = raw.trim();
151 if target.is_empty() {
152 return Err(AgentHostError::Env(
153 "agent host: transport target is required".to_string(),
154 ));
155 }
156 if let Some(address) = target.strip_prefix("tcp://") {
157 let address = address.trim();
158 if address.is_empty() {
159 return Err(AgentHostError::Env(format!(
160 "agent host: tcp target {raw:?} is missing host:port"
161 )));
162 }
163 return Ok(AgentHostTarget::Tcp(address.to_string()));
164 }
165 if let Some(address) = target.strip_prefix("tls://") {
166 let address = address.trim();
167 if address.is_empty() {
168 return Err(AgentHostError::Env(format!(
169 "agent host: tls target {raw:?} is missing host:port"
170 )));
171 }
172 return Ok(AgentHostTarget::Tls(address.to_string()));
173 }
174 if let Some(path) = target.strip_prefix("unix://") {
175 let path = path.trim();
176 if path.is_empty() {
177 return Err(AgentHostError::Env(format!(
178 "agent host: unix target {raw:?} is missing a socket path"
179 )));
180 }
181 return Ok(AgentHostTarget::Unix(path.to_string()));
182 }
183 if target.contains("://") {
184 return Err(AgentHostError::Env(format!(
185 "agent host: unsupported target scheme in {raw:?}"
186 )));
187 }
188 Ok(AgentHostTarget::Unix(target.to_string()))
189}
190
/// Extends the generated `AgentProvider` gRPC server trait with optional
/// hooks.
///
/// Every method declared here has a default implementation, so an implementor
/// only needs to override the hooks it cares about.
#[async_trait]
pub trait AgentProvider: pb::agent_provider_server::AgentProvider + Send + Sync + 'static {
    /// Hook that receives a name and a JSON configuration map.
    ///
    /// The default implementation ignores both arguments and returns `Ok(())`.
    async fn configure(
        &self,
        _name: &str,
        _config: serde_json::Map<String, serde_json::Value>,
    ) -> ProviderResult<()> {
        Ok(())
    }

    /// Returns optional runtime metadata for this provider.
    ///
    /// The default implementation returns `None`.
    fn metadata(&self) -> Option<RuntimeMetadata> {
        None
    }

    /// Returns warning messages reported by this provider.
    ///
    /// The default implementation returns an empty list.
    fn warnings(&self) -> Vec<String> {
        Vec::new()
    }

    /// Health-check hook.
    ///
    /// The default implementation returns `Ok(())`.
    async fn health_check(&self) -> ProviderResult<()> {
        Ok(())
    }

    /// Start hook.
    ///
    /// The default implementation returns `Ok(())`.
    async fn start(&self) -> ProviderResult<()> {
        Ok(())
    }

    /// Close hook.
    ///
    /// The default implementation returns `Ok(())`.
    async fn close(&self) -> ProviderResult<()> {
        Ok(())
    }
}
228
/// Wrapper around a shared provider `P`.
///
/// Cloning the server clones the `Arc`, so all clones refer to the same
/// provider instance.
#[derive(Clone)]
pub(crate) struct AgentServer<P> {
    provider: Arc<P>,
}

impl<P> AgentServer<P> {
    /// Creates a server that wraps the given shared `provider`.
    pub(crate) fn new(provider: Arc<P>) -> Self {
        Self { provider }
    }
}
239
240#[async_trait]
241impl<P> pb::agent_provider_server::AgentProvider for AgentServer<P>
242where
243 P: AgentProvider,
244{
245 async fn create_session(
246 &self,
247 request: GrpcRequest<pb::CreateAgentProviderSessionRequest>,
248 ) -> std::result::Result<GrpcResponse<pb::AgentSession>, Status> {
249 self.provider.create_session(request).await
250 }
251
252 async fn get_session(
253 &self,
254 request: GrpcRequest<pb::GetAgentProviderSessionRequest>,
255 ) -> std::result::Result<GrpcResponse<pb::AgentSession>, Status> {
256 self.provider.get_session(request).await
257 }
258
259 async fn list_sessions(
260 &self,
261 request: GrpcRequest<pb::ListAgentProviderSessionsRequest>,
262 ) -> std::result::Result<GrpcResponse<pb::ListAgentProviderSessionsResponse>, Status> {
263 self.provider.list_sessions(request).await
264 }
265
266 async fn update_session(
267 &self,
268 request: GrpcRequest<pb::UpdateAgentProviderSessionRequest>,
269 ) -> std::result::Result<GrpcResponse<pb::AgentSession>, Status> {
270 self.provider.update_session(request).await
271 }
272
273 async fn create_turn(
274 &self,
275 request: GrpcRequest<pb::CreateAgentProviderTurnRequest>,
276 ) -> std::result::Result<GrpcResponse<pb::AgentTurn>, Status> {
277 self.provider.create_turn(request).await
278 }
279
280 async fn get_turn(
281 &self,
282 request: GrpcRequest<pb::GetAgentProviderTurnRequest>,
283 ) -> std::result::Result<GrpcResponse<pb::AgentTurn>, Status> {
284 self.provider.get_turn(request).await
285 }
286
287 async fn list_turns(
288 &self,
289 request: GrpcRequest<pb::ListAgentProviderTurnsRequest>,
290 ) -> std::result::Result<GrpcResponse<pb::ListAgentProviderTurnsResponse>, Status> {
291 self.provider.list_turns(request).await
292 }
293
294 async fn cancel_turn(
295 &self,
296 request: GrpcRequest<pb::CancelAgentProviderTurnRequest>,
297 ) -> std::result::Result<GrpcResponse<pb::AgentTurn>, Status> {
298 self.provider.cancel_turn(request).await
299 }
300
301 async fn list_turn_events(
302 &self,
303 request: GrpcRequest<pb::ListAgentProviderTurnEventsRequest>,
304 ) -> std::result::Result<GrpcResponse<pb::ListAgentProviderTurnEventsResponse>, Status> {
305 self.provider.list_turn_events(request).await
306 }
307
308 async fn get_interaction(
309 &self,
310 request: GrpcRequest<pb::GetAgentProviderInteractionRequest>,
311 ) -> std::result::Result<GrpcResponse<pb::AgentInteraction>, Status> {
312 self.provider.get_interaction(request).await
313 }
314
315 async fn list_interactions(
316 &self,
317 request: GrpcRequest<pb::ListAgentProviderInteractionsRequest>,
318 ) -> std::result::Result<GrpcResponse<pb::ListAgentProviderInteractionsResponse>, Status> {
319 self.provider.list_interactions(request).await
320 }
321
322 async fn resolve_interaction(
323 &self,
324 request: GrpcRequest<pb::ResolveAgentProviderInteractionRequest>,
325 ) -> std::result::Result<GrpcResponse<pb::AgentInteraction>, Status> {
326 self.provider.resolve_interaction(request).await
327 }
328
329 async fn get_capabilities(
330 &self,
331 request: GrpcRequest<pb::GetAgentProviderCapabilitiesRequest>,
332 ) -> std::result::Result<GrpcResponse<pb::AgentProviderCapabilities>, Status> {
333 self.provider.get_capabilities(request).await
334 }
335}