// Auto-generated by BoltFFI. Do not edit.
@file:Suppress("unused", "RedundantVisibilityModifier", "MemberVisibilityCanBePrivate", "PropertyName", "FunctionName", "ClassName", "LocalVariableName", "SpellCheckingInspection", "NOTHING_TO_INLINE", "KotlinRedundantDiagnosticSuppress")
package {{ package_name }}
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlin.coroutines.Continuation
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
import kotlinx.coroutines.CancellableContinuation
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlin.jvm.JvmInline
{%- for import in extra_imports %}
import {{ import }}
{%- endfor %}
// Custom type declarations. A custom type without a native mapping is emitted
// as a plain typealias of its wire representation type. A custom type with a
// native mapping gets an object exposing decode() plus wireEncodedSize() and
// wireEncodeTo() extension functions; when a native Kotlin type is configured
// those functions convert to/from the representation, otherwise they operate
// on the representation type directly.
{%- for custom in custom_types %}
{%- if !custom.has_native_mapping %}
typealias {{ custom.class_name }} = {{ custom.repr_kotlin_type }}
{%- else %}
{% match custom.native_type %}
{% when Some with (native_type) %}
// Codec for a custom type backed by a native Kotlin type.
object {{ custom.class_name }} {
// Decodes the value from the wire using the native decode expression, or the
// representation decode expression when no native conversion is configured.
internal fun decode(reader: WireReader): {{ native_type }} {
return {% match custom.native_decode_expr %}{% when Some with (native_decode_expr) %}{{ native_decode_expr }}{% when None %}{{ custom.repr_decode_expr }}{% endmatch %}
}
}
// Wire-encoded size of this value: converts to the representation (`repr`,
// or `this` when no encode conversion is configured) and evaluates the
// representation's size expression as the result of `run`.
internal fun {{ native_type }}.wireEncodedSize(): Int = run {
val repr = {% match custom.native_encode_expr %}{% when Some with (native_encode_expr) %}{{ native_encode_expr }}{% when None %}this{% endmatch %}
{{ custom.repr_size_expr }}
}
// Appends this value's wire encoding to `wire`, going through the same
// `repr` conversion as wireEncodedSize().
internal fun {{ native_type }}.wireEncodeTo(wire: WireWriter) {
val repr = {% match custom.native_encode_expr %}{% when Some with (native_encode_expr) %}{{ native_encode_expr }}{% when None %}this{% endmatch %}
{{ custom.repr_encode_expr }}
}
{% when None %}
// Codec for a custom type with a native mapping but no native Kotlin type:
// values stay in their representation type and `repr` is simply `this`.
object {{ custom.class_name }} {
internal fun decode(reader: WireReader): {{ custom.repr_kotlin_type }} {
return {{ custom.repr_decode_expr }}
}
}
internal fun {{ custom.repr_kotlin_type }}.wireEncodedSize(): Int = run {
val repr = this
{{ custom.repr_size_expr }}
}
internal fun {{ custom.repr_kotlin_type }}.wireEncodeTo(wire: WireWriter) {
val repr = this
{{ custom.repr_encode_expr }}
}
{% endmatch %}
{%- endif %}
{%- endfor %}
/**
 * Exception describing a failed FFI operation.
 *
 * @property code numeric error code associated with the failure.
 * @param message human-readable description, exposed as [Exception.message].
 */
class FfiException(
    val code: Int,
    message: String
) : Exception(message)
/**
 * Fetches the native library's error message bytes and decodes them as UTF-8.
 *
 * NOTE(review): judging by the native symbol name this is the most recent
 * error recorded on the native side; confirm against the native implementation.
 */
private fun takeLastErrorMessage(): String {
    val rawMessage = Native.{{ prefix }}_last_error_message()
    return rawMessage.toString(Charsets.UTF_8)
}
/**
 * Process-wide coroutine scope used by the BoltFFI runtime.
 *
 * Its context combines [Dispatchers.Default] with a [SupervisorJob], so a
 * failing child coroutine does not cancel the scope or its siblings.
 */
object BoltFFIScope : CoroutineScope {
    override val coroutineContext = Dispatchers.Default + SupervisorJob()

    /**
     * Launches [block] as a child of this scope and returns its [Job].
     *
     * Note: despite the name, the coroutine is dispatched on
     * [Dispatchers.Default], not on a main/UI dispatcher.
     */
    fun launchOnMain(block: suspend CoroutineScope.() -> Unit): Job =
        launch(context = Dispatchers.Default, block = block)
}
/**
 * Exception wrapping a structured error payload held in [errorBuffer].
 *
 * On construction the buffer's byte order is switched to little-endian. Note
 * that this mutates the caller's buffer object. Little-endian is used instead
 * of `ByteOrder.nativeOrder()` because the wire format is little-endian
 * regardless of host (see WireWriter/WireReader). Native order only agreed
 * with it on little-endian machines.
 */
class BoltFFIException(val errorBuffer: ByteBuffer) : Exception("Structured error") {
    init {
        errorBuffer.order(ByteOrder.LITTLE_ENDIAN)
    }
}
/**
 * Wraps [bytes] (no copy) in a little-endian [ByteBuffer] and returns the
 * result of applying [block] to it.
 */
private inline fun <T> useWireBytes(bytes: ByteArray, block: (java.nio.ByteBuffer) -> T): T {
    val wrapped = java.nio.ByteBuffer.wrap(bytes)
    wrapped.order(java.nio.ByteOrder.LITTLE_ENDIAN)
    return block(wrapped)
}
/** Poll status meaning the native future has finished; boltffiCallAsync stops polling on this value. */
private const val BOLTFFI_FUTURE_POLL_READY: Byte = 0x00
/** Poll status for a future that is not ready yet; boltffiCallAsync re-polls on any non-READY value. */
private const val BOLTFFI_FUTURE_POLL_WAKE: Byte = 0x01
/**
 * Concurrent registry that maps opaque `Long` handles to Kotlin objects, so
 * objects can be referenced from native code by number.
 *
 * Handles come from an atomic counter that starts at 1 and advances by 2, so
 * every issued handle is odd. Lookups of unknown handles via [get]/[remove]
 * throw [IllegalStateException]; [tryRemove] returns `null` instead.
 */
internal class BoltFFIHandleMap<T: Any> {
    private val entries = ConcurrentHashMap<Long, T>()
    private val nextHandle = AtomicLong(1)

    /** Registers [obj] and returns the freshly allocated handle for it. */
    fun insert(obj: T): Long =
        nextHandle.getAndAdd(2).also { handle -> entries[handle] = obj }

    /** Unregisters and returns the object for [handle]; throws if it is unknown. */
    fun remove(handle: Long): T =
        checkNotNull(entries.remove(handle)) { "BoltFFIHandleMap: invalid handle $handle" }

    /** Unregisters and returns the object for [handle], or `null` if it is unknown. */
    fun tryRemove(handle: Long): T? = entries.remove(handle)

    /** Returns the object for [handle] without unregistering it; throws if it is unknown. */
    fun get(handle: Long): T =
        checkNotNull(entries[handle]) { "BoltFFIHandleMap: invalid handle $handle" }
}
/** Suspended continuations waiting for a future poll status, keyed by the handle given to native code. */
private val boltffiContinuationMap = BoltFFIHandleMap<CancellableContinuation<Byte>>()

/**
 * Drives a native future to completion from a coroutine.
 *
 * Steps:
 * 1. [createFuture] creates the future.
 * 2. The coroutine repeatedly suspends. Each suspension registers its
 *    continuation in [boltffiContinuationMap] and passes the resulting handle
 *    to [poll].
 * 3. The coroutine resumes with a poll status byte. This presumably happens
 *    via a native callback that looks the handle up; that code is not visible
 *    here.
 * 4. Polling repeats until the status equals [BOLTFFI_FUTURE_POLL_READY]. Then
 *    [complete] produces the result.
 *
 * Once the future exists, [free] runs on every exit path: success, exception,
 * or cancellation. If the coroutine is cancelled while suspended, the
 * continuation's cancellation handler calls [cancel] on the future.
 */
internal suspend inline fun <T> boltffiCallAsync(
    crossinline createFuture: () -> Long,
    crossinline poll: (Long, Long) -> Unit,
    crossinline complete: (Long) -> T,
    crossinline free: (Long) -> Unit,
    crossinline cancel: (Long) -> Unit
): T {
    val future = createFuture()
    try {
        while (true) {
            val status = suspendCancellableCoroutine<Byte> { continuation ->
                continuation.invokeOnCancellation { cancel(future) }
                val handle = boltffiContinuationMap.insert(continuation)
                poll(future, handle)
            }
            if (status == BOLTFFI_FUTURE_POLL_READY) break
        }
        return complete(future)
    } finally {
        free(future)
    }
}
{%- if has_streams %}
typealias SubscriptionHandle = Long
private const val BOLTFFI_STREAM_POLL_CLOSED: Byte = 1
// Coordinates the lifecycle of a native stream subscription.
//
// Two atomic tags prevent races between the poll callback (fired from native
// code on any thread) and user-initiated cancellation:
//
// lifecycleTag: 0 = running, 1 = unsubscribe called, 2 = unsubscribed, 3 = freed.
//   Transitions are one-way. Once it leaves 0, no more polls register.
//
// callbackTag: 0 = idle, 1 = processing batches in handlePoll.
//   Prevents a second poll callback from entering handlePoll while
//   the first is still draining popBatch.
//
// attemptFinalize() only proceeds when both tags are in their final positions:
// callbackTag == 0 (no active processing) and lifecycleTag == 2 (unsubscribed).
// It then transitions lifecycleTag to 3 and calls freeFn + finish.
//
// Failure handling: draining can fail in two ways. popBatch may return null,
// or popBatch/processItems may throw. In either case the subscription is
// terminated through requestTermination() before the failure is rethrown.
// This way the native handle is still unsubscribed and freed, and finish()
// still runs, instead of the stream being abandoned in the running state.
internal class BoltFFIStreamContext(
    private val scope: CoroutineScope,
    private val subscription: SubscriptionHandle,
    private val batchSize: Long,
    private val popBatch: (SubscriptionHandle, Long) -> ByteArray?,
    private val poll: (SubscriptionHandle, Long) -> Unit,
    private val unsubscribe: (SubscriptionHandle) -> Unit,
    private val freeFn: (SubscriptionHandle) -> Unit,
    private val processItems: (WireReader) -> Unit,
    private val finish: () -> Unit
) {
    // 0 = running, 1 = unsubscribe called, 2 = unsubscribed, 3 = freed.
    // Fully qualified because only AtomicBoolean/AtomicLong are imported at the top of the file.
    private val lifecycleTag = java.util.concurrent.atomic.AtomicInteger(0)
    // 0 = idle, 1 = processing. Guards handlePoll against concurrent entry.
    private val callbackTag = java.util.concurrent.atomic.AtomicInteger(0)

    /** Registers the first poll for this subscription. */
    fun start() {
        registerPoll()
    }

    /**
     * Requests shutdown of the subscription. Safe to call more than once.
     *
     * Only the caller that wins the 0 -> 1 transition calls [unsubscribe].
     * Every caller then attempts finalization.
     */
    fun requestTermination() {
        val started = lifecycleTag.compareAndSet(0, 1)
        if (started) {
            unsubscribe(subscription)
            lifecycleTag.compareAndSet(1, 2)
        }
        attemptFinalize()
    }

    /** Frees the subscription and calls [finish] once it is unsubscribed and idle; at most once. */
    private fun attemptFinalize() {
        // compareAndSet(0, 0) acts as an atomic "== 0" check.
        if (!callbackTag.compareAndSet(0, 0)) return
        // Only one caller can win 2 -> 3, so freeFn/finish run at most once.
        if (!lifecycleTag.compareAndSet(2, 3)) return
        freeFn(subscription)
        finish()
    }

    /** Registers the next poll from a fresh coroutine in [scope]. */
    private fun schedulePoll() {
        scope.launch { registerPoll() }
    }

    /** Suspends on a native poll and handles its result, unless termination has begun. */
    private fun registerPoll() {
        if (!lifecycleTag.compareAndSet(0, 0)) {
            attemptFinalize()
            return
        }
        scope.launch {
            val pollResult = suspendCancellableCoroutine<Byte> { continuation ->
                poll(subscription, boltffiContinuationMap.insert(continuation))
            }
            handlePoll(pollResult)
        }
    }

    /**
     * Drains available batches after a poll result and decides what happens next.
     *
     * After draining, it does one of three things:
     * - terminates the stream when the poll reported CLOSED,
     * - re-polls while still running,
     * - or terminates and rethrows if draining failed.
     */
    private fun handlePoll(pollResult: Byte) {
        val isClosed = pollResult == BOLTFFI_STREAM_POLL_CLOSED
        if (!callbackTag.compareAndSet(0, 1)) {
            attemptFinalize()
            return
        }
        var failure: Throwable? = null
        try {
            if (!lifecycleTag.compareAndSet(0, 0)) return
            drainBatches()
        } catch (t: Throwable) {
            failure = t
        } finally {
            callbackTag.compareAndSet(1, 0)
            attemptFinalize()
        }
        if (failure != null) {
            // Release the native subscription and signal finish() before propagating,
            // otherwise the subscription would stay "running" forever.
            try {
                requestTermination()
            } catch (cleanup: Throwable) {
                failure.addSuppressed(cleanup)
            }
            throw failure
        }
        if (isClosed) {
            requestTermination()
            return
        }
        if (!lifecycleTag.compareAndSet(0, 0)) return
        schedulePoll()
    }

    /** Pops batches of up to [batchSize] items and feeds each to [processItems] until a batch is empty. */
    private fun drainBatches() {
        while (true) {
            val bytes = popBatch(subscription, batchSize)
                ?: throw RuntimeException("BoltFFI: stream pop_batch failed (null)")
            if (bytes.isEmpty()) break
            processItems(WireReader(bytes))
        }
    }
}
{%- endif %}
interface WireCodable {
fun encode(writer: WireWriter)
interface Decoder<T> {
fun decode(reader: WireReader): T
}
}
/**
 * Cursor-based decoder for the little-endian BoltFFI wire format.
 *
 * A reader wraps one of two backing stores:
 * - a heap [ByteArray], where multi-byte values are assembled byte by byte;
 * - a [ByteBuffer], which is sliced from its current position, set to
 *   little-endian, and read with absolute accessors.
 *
 * Reading starts at offset 0, and every `read*` call advances the cursor past
 * the bytes it consumed. There is no explicit bounds checking: reading past
 * the end fails with whatever exception the underlying array/buffer access
 * throws.
 */
class WireReader {
    /** Heap backing store; non-null exactly when [direct] is null. */
    private val heap: ByteArray?
    /** Buffer backing store (little-endian slice); non-null exactly when [heap] is null. */
    private val direct: ByteBuffer?
    /** Offset of the next unread byte. */
    private var cursor: Int = 0

    /** Reads from [array], starting at index 0. */
    constructor(array: ByteArray) {
        heap = array
        direct = null
    }

    /** Reads from [buffer], starting at its current position; the caller's buffer is not modified. */
    constructor(buffer: ByteBuffer) {
        heap = null
        direct = buffer.slice().order(ByteOrder.LITTLE_ENDIAN)
    }

    /** Advances the cursor by [length] bytes and returns the offset where they start. */
    private fun advance(length: Int): Int {
        val start = cursor
        cursor += length
        return start
    }

    private fun byteAt(index: Int): Byte {
        val bytes = heap
        return if (bytes != null) bytes[index] else direct!!.get(index)
    }

    /** Copies [len] bytes starting at absolute offset [start] into a new array. */
    private fun copyRange(start: Int, len: Int): ByteArray {
        heap?.let { return it.copyOfRange(start, start + len) }
        val out = ByteArray(len)
        if (len == 0) return out
        val window = direct!!.duplicate().order(ByteOrder.LITTLE_ENDIAN)
        window.position(start)
        window.limit(start + len)
        window.get(out)
        return out
    }

    /** Little-endian 16-bit value at absolute offset [o]; does not move the cursor. */
    @PublishedApi internal fun leI16(o: Int): Short {
        val bytes = heap ?: return direct!!.getShort(o)
        val low = bytes[o].toInt() and 0xFF
        val high = bytes[o + 1].toInt() and 0xFF
        return (low or (high shl 8)).toShort()
    }

    /** Little-endian 32-bit value at absolute offset [o]; does not move the cursor. */
    @PublishedApi internal fun leI32(o: Int): Int {
        val bytes = heap ?: return direct!!.getInt(o)
        var value = 0
        for (i in 0 until 4) {
            value = value or ((bytes[o + i].toInt() and 0xFF) shl (8 * i))
        }
        return value
    }

    /** Little-endian 64-bit value at absolute offset [o]; does not move the cursor. */
    @PublishedApi internal fun leI64(o: Int): Long {
        val bytes = heap ?: return direct!!.getLong(o)
        var value = 0L
        for (i in 0 until 8) {
            value = value or ((bytes[o + i].toLong() and 0xFFL) shl (8 * i))
        }
        return value
    }

    fun readBool(): Boolean = readI8() != 0.toByte()
    fun readI8(): Byte = byteAt(cursor).also { cursor += 1 }
    fun readU8(): UByte = readI8().toUByte()
    fun readI16(): Short = leI16(cursor).also { cursor += 2 }
    fun readU16(): UShort = readI16().toUShort()
    fun readI32(): Int = leI32(cursor).also { cursor += 4 }
    fun readU32(): UInt = readI32().toUInt()
    fun readI64(): Long = leI64(cursor).also { cursor += 8 }
    fun readU64(): ULong = readI64().toULong()
    fun readF32(): Float = Float.fromBits(readI32())
    fun readF64(): Double = Double.fromBits(readI64())

    /** Reads a 4-byte length prefix followed by that many UTF-8 bytes. */
    fun readString(): String {
        val len = readI32()
        if (len == 0) return ""
        val start = advance(len)
        val bytes = heap ?: return String(copyRange(start, len), Charsets.UTF_8)
        return String(bytes, start, len, Charsets.UTF_8)
    }

    /** Reads a 4-byte length prefix followed by that many raw bytes (copied). */
    fun readBytes(): ByteArray {
        val len = readI32()
        return if (len == 0) ByteArray(0) else readBlittable(len)
    }

    /** Reads (i64 seconds, i32 nanos); both parts must be non-negative. */
    fun readDuration(): Duration {
        val seconds = readI64()
        val nanos = readI32().toLong()
        require(seconds >= 0L) { "Duration out of range" }
        require(nanos >= 0L) { "Duration nanos out of range" }
        return Duration.ofSeconds(seconds, nanos)
    }

    /**
     * Reads an instant as (i64 seconds, i32 nanos) relative to the epoch.
     * The sign of `seconds` gives the direction; `nanos` is a magnitude added
     * in that same direction.
     */
    fun readInstant(): Instant {
        val seconds = readI64()
        val nanos = readI32().toLong()
        require(nanos >= 0L) { "Instant nanos out of range" }
        return when {
            seconds >= 0L -> Instant.EPOCH.plus(Duration.ofSeconds(seconds, nanos))
            else -> {
                // -Long.MIN_VALUE would overflow back to Long.MIN_VALUE.
                require(seconds != Long.MIN_VALUE) { "Instant out of range" }
                Instant.EPOCH.minus(Duration.ofSeconds(-seconds, nanos))
            }
        }
    }

    /** Reads the most-significant then least-significant 64-bit halves. */
    fun readUuid(): UUID {
        val high = readI64()
        val low = readI64()
        return UUID(high, low)
    }

    fun readUri(): URI = URI.create(readString())

    /** Reads a one-byte presence tag (0 = absent), then the value via [reader] if present. */
    inline fun <T> readOptional(reader: (WireReader) -> T): T? =
        if (readU8().toInt() == 0) null else reader(this)

    /** Reads a 4-byte element count, then that many elements via [reader]. */
    inline fun <T> readList(reader: (WireReader) -> T): List<T> {
        val count = readI32()
        if (count == 0) return emptyList()
        val items = ArrayList<T>(count)
        repeat(count) { items.add(reader(this)) }
        return items
    }

    /** Reads a one-byte tag (0 = Ok, anything else = Err), then the matching payload. */
    inline fun <T, E> readResult(okReader: (WireReader) -> T, errReader: (WireReader) -> E): BoltFFIResult<T, E> =
        when (readU8().toInt()) {
            0 -> BoltFFIResult.Ok(okReader(this))
            else -> BoltFFIResult.Err(errReader(this))
        }

    fun readShortArray(): ShortArray {
        val count = readI32()
        if (count == 0) return ShortArray(0)
        val base = advance(count * 2)
        if (heap != null) return ShortArray(count) { leI16(base + it * 2) }
        val buf = direct!!
        return ShortArray(count) { buf.getShort(base + it * 2) }
    }

    fun readIntArray(): IntArray {
        val count = readI32()
        if (count == 0) return IntArray(0)
        val base = advance(count * 4)
        if (heap != null) return IntArray(count) { leI32(base + it * 4) }
        val buf = direct!!
        return IntArray(count) { buf.getInt(base + it * 4) }
    }

    fun readLongArray(): LongArray {
        val count = readI32()
        if (count == 0) return LongArray(0)
        val base = advance(count * 8)
        if (heap != null) return LongArray(count) { leI64(base + it * 8) }
        val buf = direct!!
        return LongArray(count) { buf.getLong(base + it * 8) }
    }

    fun readFloatArray(): FloatArray {
        val count = readI32()
        if (count == 0) return FloatArray(0)
        val base = advance(count * 4)
        if (heap != null) return FloatArray(count) { Float.fromBits(leI32(base + it * 4)) }
        val buf = direct!!
        return FloatArray(count) { buf.getFloat(base + it * 4) }
    }

    fun readDoubleArray(): DoubleArray {
        val count = readI32()
        if (count == 0) return DoubleArray(0)
        val base = advance(count * 8)
        if (heap != null) return DoubleArray(count) { Double.fromBits(leI64(base + it * 8)) }
        val buf = direct!!
        return DoubleArray(count) { buf.getDouble(base + it * 8) }
    }

    /** Reads a 4-byte count followed by one byte per element (non-zero = true). */
    fun readBooleanArray(): BooleanArray {
        val count = readI32()
        if (count == 0) return BooleanArray(0)
        val base = advance(count)
        val bytes = heap
        if (bytes != null) return BooleanArray(count) { bytes[base + it] != 0.toByte() }
        val buf = direct!!
        return BooleanArray(count) { buf.get(base + it) != 0.toByte() }
    }

    /** Moves the cursor forward by [count] bytes without reading them. */
    fun skip(count: Int) {
        advance(count)
    }

    /** Copies the next [sizeBytes] bytes verbatim and advances past them. */
    fun readBlittable(sizeBytes: Int): ByteArray = copyRange(advance(sizeBytes), sizeBytes)
}
/**
 * Outcome of a fallible call: either [Ok] with a value or [Err] with an error.
 *
 * When an [Err] must be surfaced as a [Throwable] (see [getOrThrow] and
 * [exceptionOrNull]), an error that already is a [Throwable] is used as-is.
 * Any other error is wrapped in an [FfiException] with code -1 and the
 * error's `toString()` as its message.
 */
sealed class BoltFFIResult<out T, out E> {
    data class Ok<T>(val value: T) : BoltFFIResult<T, Nothing>()
    data class Err<E>(val error: E) : BoltFFIResult<Nothing, E>()

    val isSuccess: Boolean get() = this is Ok
    val isFailure: Boolean get() = !isSuccess

    /** Returns the success value or throws the error as described on the class. */
    fun getOrThrow(): T = when (this) {
        is Ok -> value
        is Err -> throw errorAsThrowable(error)
    }

    /** Returns the success value, or `null` for [Err]. */
    fun getOrNull(): T? = if (this is Ok) value else null

    /** Returns the error as a [Throwable], or `null` for [Ok]. */
    fun exceptionOrNull(): Throwable? = if (this is Err) errorAsThrowable(error) else null

    /** Applies [onSuccess] or [onFailure] depending on the case and returns its result. */
    inline fun <R> fold(onSuccess: (T) -> R, onFailure: (E) -> R): R = when (this) {
        is Ok -> onSuccess(value)
        is Err -> onFailure(error)
    }

    private fun errorAsThrowable(error: Any?): Throwable =
        error as? Throwable ?: FfiException(-1, error.toString())
}
/**
 * Reinterprets [value] as [T] with an unchecked cast.
 *
 * Because [T] is erased, no type check happens here. A mismatch surfaces as
 * a ClassCastException where the caller uses the result as a concrete type.
 */
@Suppress("UNCHECKED_CAST")
internal fun <T> boltffiUnsafeCast(value: Any?): T {
    val reinterpreted = value as T
    return reinterpreted
}
/** UTF-8 sizing helpers. */
private object Utf8Codec {
    /**
     * Upper bound on the UTF-8 encoded size of [value] in bytes.
     *
     * Each UTF-16 code unit encodes to at most three bytes. A surrogate pair,
     * which is two units, encodes to four. The Int product can overflow for
     * strings longer than Int.MAX_VALUE / 3.
     */
    fun maxBytes(value: String): Int = 3 * value.length
}
/**
 * Growable encoder for the little-endian BoltFFI wire format.
 *
 * Bytes are staged in a direct [ByteBuffer], and an internal cursor marks the
 * end of the encoded data. When a write would overflow, the storage is
 * replaced by a new direct buffer whose capacity is max(2 * old capacity,
 * required). The encoded prefix is copied into it. [reset] rewinds the writer
 * for reuse. No internal synchronization is performed.
 */
class WireWriter(initialCapacity: Int = 256) {
    private var storage: ByteBuffer = allocate(initialCapacity)
    private var cursor: Int = 0

    /** Allocates a little-endian direct buffer of exactly [capacity] bytes. */
    private fun allocate(capacity: Int): ByteBuffer =
        ByteBuffer.allocateDirect(capacity).order(ByteOrder.LITTLE_ENDIAN)

    /** Little-endian slice of [target] covering [length] bytes from absolute [offset]. */
    private fun window(target: ByteBuffer, offset: Int, length: Int): ByteBuffer {
        val dup = target.duplicate().order(ByteOrder.LITTLE_ENDIAN)
        dup.position(offset)
        dup.limit(offset + length)
        return dup.slice().order(ByteOrder.LITTLE_ENDIAN)
    }

    /** Rewinds to empty, reallocating only if the current storage is smaller than [requiredCapacity]. */
    internal fun reset(requiredCapacity: Int) {
        if (requiredCapacity > storage.capacity()) storage = allocate(requiredCapacity)
        cursor = 0
    }

    /** Little-endian view (sharing storage) of the bytes written so far. */
    internal fun asDirectBuffer(): ByteBuffer = window(storage, 0, cursor)

    /** Copies the bytes written so far into a new heap array. */
    internal fun toByteArray(): ByteArray {
        val out = ByteArray(cursor)
        val dup = storage.duplicate()
        dup.position(0)
        dup.get(out, 0, cursor)
        return out
    }

    private fun ensureCapacity(needed: Int) {
        val required = cursor + needed
        val capacity = storage.capacity()
        if (required <= capacity) return
        val grown = allocate(maxOf(capacity * 2, required))
        grown.put(window(storage, 0, cursor))
        storage = grown
    }

    /**
     * Reserves [byteCount] bytes and lets [writer] fill them.
     * [writer] receives the backing buffer and the absolute offset of the reserved region.
     */
    internal inline fun writeRawBytes(byteCount: Int, writer: (ByteBuffer, Int) -> Unit) {
        ensureCapacity(byteCount)
        val start = cursor
        writer(storage, start)
        cursor = start + byteCount
    }

    fun writeBool(v: Boolean) = writeI8(if (v) 1 else 0)

    /** Appends [count] zero bytes. */
    fun writePadding(count: Int) = writeRawBytes(count) { buf, at ->
        for (i in 0 until count) buf.put(at + i, 0.toByte())
    }

    fun writeI8(v: Byte) = writeRawBytes(1) { buf, at -> buf.put(at, v) }
    fun writeU8(v: UByte) = writeI8(v.toByte())
    fun writeI16(v: Short) = writeRawBytes(2) { buf, at -> buf.putShort(at, v) }
    fun writeU16(v: UShort) = writeI16(v.toShort())
    fun writeI32(v: Int) = writeRawBytes(4) { buf, at -> buf.putInt(at, v) }
    fun writeU32(v: UInt) = writeI32(v.toInt())
    fun writeI64(v: Long) = writeRawBytes(8) { buf, at -> buf.putLong(at, v) }
    fun writeU64(v: ULong) = writeI64(v.toLong())
    fun writeF32(v: Float) = writeI32(v.toRawBits())
    fun writeF64(v: Double) = writeI64(v.toRawBits())

    /** Writes (i64 seconds, i32 nanos); negative durations are rejected. */
    fun writeDuration(v: Duration) {
        require(v.seconds >= 0L) { "Invalid duration, must be non-negative" }
        require(v.nano >= 0) { "Invalid duration nanos" }
        writeI64(v.seconds)
        writeI32(v.nano)
    }

    /**
     * Writes an instant as signed seconds plus non-negative nanos of the
     * absolute offset from the epoch (sign-magnitude).
     *
     * NOTE(review): an instant strictly between EPOCH-1s and EPOCH is written
     * with seconds == 0. Its sign is therefore lost, and WireReader.readInstant
     * decodes it as after the epoch. Confirm the intended wire format with the
     * native side.
     */
    fun writeInstant(v: Instant) {
        val offset = Duration.between(Instant.EPOCH, v)
        val negative = offset.isNegative
        val magnitude = if (negative) offset.negated() else offset
        require(magnitude.nano >= 0) { "Invalid instant nanos" }
        writeI64(if (negative) -magnitude.seconds else magnitude.seconds)
        writeI32(magnitude.nano)
    }

    fun writeUuid(v: UUID) {
        val high = v.mostSignificantBits
        val low = v.leastSignificantBits
        writeI64(high)
        writeI64(low)
    }

    fun writeUri(v: URI) = writeString(v.toString())

    /** Writes a u32 byte-length prefix followed by the UTF-8 bytes of [v]. */
    fun writeString(v: String) = writeBytes(v.toByteArray(Charsets.UTF_8))

    /** Writes a u32 length prefix followed by the raw bytes of [v]. */
    fun writeBytes(v: ByteArray) {
        writeU32(v.size.toUInt())
        writeRawBytes(v.size) { buf, at ->
            val dup = buf.duplicate()
            dup.position(at)
            dup.put(v)
        }
    }

    fun writePrimitiveList(v: IntArray) {
        writeU32(v.size.toUInt())
        val byteCount = v.size * 4
        writeRawBytes(byteCount) { buf, at -> window(buf, at, byteCount).asIntBuffer().put(v) }
    }

    fun writePrimitiveList(v: LongArray) {
        writeU32(v.size.toUInt())
        val byteCount = v.size * 8
        writeRawBytes(byteCount) { buf, at -> window(buf, at, byteCount).asLongBuffer().put(v) }
    }

    fun writePrimitiveList(v: FloatArray) {
        writeU32(v.size.toUInt())
        val byteCount = v.size * 4
        writeRawBytes(byteCount) { buf, at -> window(buf, at, byteCount).asFloatBuffer().put(v) }
    }

    fun writePrimitiveList(v: DoubleArray) {
        writeU32(v.size.toUInt())
        val byteCount = v.size * 8
        writeRawBytes(byteCount) { buf, at -> window(buf, at, byteCount).asDoubleBuffer().put(v) }
    }

    fun writePrimitiveList(v: ShortArray) {
        writeU32(v.size.toUInt())
        val byteCount = v.size * 2
        writeRawBytes(byteCount) { buf, at -> window(buf, at, byteCount).asShortBuffer().put(v) }
    }

    /** Same encoding as [writeBytes]: u32 count then the raw bytes. */
    fun writePrimitiveList(v: ByteArray) = writeBytes(v)

    fun writePrimitiveList(v: BooleanArray) {
        writeU32(v.size.toUInt())
        for (flag in v) writeBool(flag)
    }

    @JvmName("writeIntList")
    fun writePrimitiveList(v: List<Int>) {
        writeU32(v.size.toUInt())
        for (item in v) writeI32(item)
    }

    @JvmName("writeLongList")
    fun writePrimitiveList(v: List<Long>) {
        writeU32(v.size.toUInt())
        for (item in v) writeI64(item)
    }

    /** Writes a single primitive; any other runtime type is rejected. */
    inline fun <reified T> writeBlittable(v: T) {
        when (v) {
            is Boolean -> writeBool(v)
            is Byte -> writeI8(v)
            is Short -> writeI16(v)
            is Int -> writeI32(v)
            is Long -> writeI64(v)
            is Float -> writeF32(v)
            is Double -> writeF64(v)
            else -> throw IllegalArgumentException("Cannot write blittable: ${T::class}")
        }
    }

    /** Writes a u32 count followed by each element via [writeBlittable]. */
    inline fun <reified T> writeBlittableList(v: List<T>) {
        writeU32(v.size.toUInt())
        for (item in v) writeBlittable(item)
    }
}
/** Requests needing more than this many bytes (1 MiB) always get a fresh, uncached writer. */
private const val MAX_CACHED_WIRE_WRITER_BYTES: Int = 1024 * 1024
/**
 * Stack-like cache of [WireWriter]s for nested encode operations.
 *
 * Each [acquire] takes the slot at the current nesting depth, and each
 * [release] pops one level. Borrows must therefore be released in LIFO order.
 * The first [cacheSize] depths reuse a cached writer, unless the request
 * exceeds [MAX_CACHED_WIRE_WRITER_BYTES]. Deeper or oversized requests get a
 * throwaway writer. There is no synchronization here: an instance is expected
 * to be confined to one thread (WireWriterPool keeps one per thread in a
 * ThreadLocal).
 *
 * NOTE(review): a cached writer that grows past MAX_CACHED_WIRE_WRITER_BYTES
 * while being written keeps its enlarged buffer in the cache.
 */
internal class WireWriterPoolState(private val cacheSize: Int = 4) {
    private val cachedWriters: Array<WireWriter?> = arrayOfNulls(cacheSize)
    private var depth: Int = 0

    /** Borrows a writer reset to empty with at least [requiredCapacity] bytes of storage. */
    fun acquire(requiredCapacity: Int): BorrowedWireWriter {
        val slot = depth
        val shouldCache = requiredCapacity <= MAX_CACHED_WIRE_WRITER_BYTES && slot < cacheSize
        val writer = if (shouldCache) {
            cachedWriters[slot] ?: WireWriter(requiredCapacity).also { cachedWriters[slot] = it }
        } else {
            WireWriter(requiredCapacity)
        }
        writer.reset(requiredCapacity)
        // Claim the slot only once the writer is ready, so a failed allocation
        // does not leave the depth permanently incremented.
        depth = slot + 1
        return BorrowedWireWriter(this, writer)
    }

    /**
     * Pops one nesting level.
     *
     * @throws IllegalStateException if there is no outstanding borrow. Without
     *   this check, depth would go negative and the next acquire would index
     *   slot -1.
     */
    fun release() {
        check(depth > 0) { "WireWriterPoolState: release() called without a matching acquire()" }
        depth -= 1
    }
}
/**
 * A [WireWriter] on loan from a [WireWriterPoolState]. Closing it releases the
 * pool slot, and being [AutoCloseable] it works with `use { }`.
 *
 * [close] is idempotent: only the first call releases the slot, so an
 * accidental double close cannot unbalance the pool's depth counter.
 */
internal class BorrowedWireWriter(
    private val state: WireWriterPoolState,
    internal val writer: WireWriter
) : AutoCloseable {
    // Plain field, not atomic: the pool state is per-thread (WireWriterPool holds it in a
    // ThreadLocal), so a borrow is expected to be closed on the thread that acquired it.
    private var released = false

    /** Little-endian view of the bytes written so far; a new view is created on each access. */
    internal val buffer: ByteBuffer
        get() = writer.asDirectBuffer()

    override fun close() {
        if (released) return
        released = true
        state.release()
    }
}
/** Entry point for borrowing pooled [WireWriter]s; each thread gets its own pool state. */
internal object WireWriterPool {
    private val state: ThreadLocal<WireWriterPoolState> = ThreadLocal.withInitial { WireWriterPoolState() }

    /** Borrows a writer with at least [requiredCapacity] bytes from the calling thread's pool. */
    fun acquire(requiredCapacity: Int): BorrowedWireWriter =
        currentState().acquire(requiredCapacity)

    /** Returns this thread's pool state, installing a fresh one if none is present. */
    private fun currentState(): WireWriterPoolState {
        state.get()?.let { return it }
        val fresh = WireWriterPoolState()
        state.set(fresh)
        return fresh
    }
}